# Character-level text generation with a stacked LSTM (Keras).
# Trains on warandpeace.txt and prints a sampled continuation.
|
from tensorflow.python.keras.utils.np_utils import to_categorical
|
||
|
from keras.layers import Dense, Dropout, LSTM
|
||
|
from keras.models import Sequential
|
||
|
import numpy as np
|
||
|
|
||
|
# Training hyperparameters for the character-level LSTM text generator.
N_EPOCHS = 64         # number of passes over the training data
BATCH_SIZE = 128      # samples per gradient update
N_UNITS = 256         # LSTM hidden units per layer
DROPOUT_RATE = 0.2    # dropout applied after each LSTM layer
SEQ_LEN = 100         # characters per input training window
PREDICTION_LEN = 100  # characters to generate after training
|
||
|
|
||
|
|
||
|
def prepare_text(text, ids, seq_length):
    """Slice *text* into overlapping training windows.

    Every window of ``seq_length`` consecutive characters is mapped to
    integer ids via *ids*; its target is the id of the character that
    immediately follows the window.

    Returns:
        (inputs, targets): two parallel lists — each input is a list of
        ``seq_length`` ids, each target a single id.
    """
    pairs = [
        ([ids[ch] for ch in text[start:start + seq_length]],
         ids[text[start + seq_length]])
        for start in range(len(text) - seq_length)
    ]
    dataX = [window for window, _ in pairs]
    dataY = [target for _, target in pairs]
    return dataX, dataY
|
||
|
|
||
|
|
||
|
def generate_text(model, pattern, chars, n_vocab, prediction_len):
    """Generate *prediction_len* characters by sampling the model greedily.

    Args:
        model: trained model whose ``predict`` returns per-class scores.
        pattern: seed sequence of integer character ids (left unmodified).
        chars: mapping from integer id back to character.
        n_vocab: vocabulary size, used to normalise inputs to [0, 1).
        prediction_len: number of characters to generate.

    Returns:
        The generated text as a single string.
    """
    # Work on a copy: the original implementation appended to the caller's
    # list, growing the seed by prediction_len elements as a side effect.
    pattern = list(pattern)
    pieces = []
    for _ in range(prediction_len):
        x = np.reshape(pattern, (1, len(pattern), 1)) / float(n_vocab)
        prediction = model.predict(x, verbose=0)
        index = int(np.argmax(prediction))
        pieces.append(chars[index])
        # Slide the window: drop the oldest id, append the newest.
        pattern = pattern[1:] + [index]
    # join() instead of repeated += (quadratic string concatenation).
    return "".join(pieces)
|
||
|
|
||
|
class Model(Sequential):
    """Two-layer LSTM character model with dropout and a softmax head."""

    def __init__(self, x, y):
        """Build the network from the shapes of the training arrays.

        x: input array shaped (samples, timesteps, features).
        y: one-hot target array shaped (samples, n_classes).
        """
        super().__init__()
        layers = [
            LSTM(N_UNITS, input_shape=(x.shape[1], x.shape[2]),
                 return_sequences=True),
            Dropout(DROPOUT_RATE),
            LSTM(N_UNITS),
            Dropout(DROPOUT_RATE),
            Dense(y.shape[1], activation='softmax'),
        ]
        for layer in layers:
            self.add(layer)

    def compile_model(self):
        """Configure the model for training."""
        self.compile(optimizer='adam', loss='categorical_crossentropy')

    def fit_model(self, x, y, batch_size, epochs):
        """Train on (x, y) with the given batch size and epoch count."""
        self.fit(x, y, batch_size=batch_size, epochs=epochs)
|
||
|
|
||
|
|
||
|
# --- Script: train on the corpus, then sample from the trained model ---

# Read the corpus; `with` guarantees the file handle is closed
# (the original left it dangling).
with open('warandpeace.txt', 'r', encoding='UTF-8') as f:
    text = f.read().lower()

# Character vocabulary and the two id <-> char lookup tables.
vocab = sorted(set(text))
n_vocab = len(vocab)
ids = {c: i for i, c in enumerate(vocab)}
chars = {i: c for i, c in enumerate(vocab)}

# Build training windows, reshape to (samples, timesteps, 1) and
# normalise ids to [0, 1) for the LSTM; targets become one-hot vectors.
dataX, dataY = prepare_text(text=text, ids=ids, seq_length=SEQ_LEN)
n_patterns = len(dataX)
X = np.reshape(dataX, (n_patterns, SEQ_LEN, 1))
X = X / float(n_vocab)
y = to_categorical(dataY)

model = Model(X, y)
model.compile_model()
model.fit_model(X, y, BATCH_SIZE, N_EPOCHS)

# Seed generation with a random training window. np.random.randint's
# upper bound is exclusive, so len(dataX) already includes every valid
# index — the original's len(dataX) - 1 could never pick the last one.
pattern = dataX[np.random.randint(0, len(dataX))]
prediction = generate_text(model, pattern, chars, n_vocab, PREDICTION_LEN)
print("\nPREDICTION: ", prediction)
|