37 lines
951 B
Python
37 lines
951 B
Python
|
from tensorflow.python.keras.utils.np_utils import to_categorical
|
||
|
from utils import SEQ_LEN, N_EPOCHS, BATCH_SIZE, PREDICTION_LEN
|
||
|
from data import prepare_text, generate_text, read_text, peek_random_seed
|
||
|
from model import Model
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
def main():
    """Train the character-level model on ``en.txt`` and print a generated sample.

    Steps: read the corpus, build a sorted character vocabulary, turn the text
    into fixed-length training windows with ``prepare_text``, fit ``Model``,
    then generate ``PREDICTION_LEN`` characters seeded from a training window.

    Raises:
        SystemExit: if ``prepare_text`` produces no training patterns.
    """
    text = read_text('en.txt')

    # Sorting makes the char <-> index mapping deterministic for a given corpus.
    vocab = sorted(set(text))
    n_vocab = len(vocab)
    ids = {c: i for i, c in enumerate(vocab)}  # char -> index
    chars = dict(enumerate(vocab))             # index -> char

    dataX, dataY = prepare_text(text=text, ids=ids, seq_length=SEQ_LEN)
    n_patterns = len(dataX)
    if n_patterns == 0:
        # Fail early with a clear message. Otherwise the one-hot encoding and
        # the seed sampling below would fail on empty data with an obscure error.
        raise SystemExit(
            'No training patterns were produced from en.txt '
            '(is the text longer than SEQ_LEN={}?)'.format(SEQ_LEN))

    # Shape (n_patterns, SEQ_LEN, 1): one feature per time step.
    # Presumably dataX holds indices from `ids`, so dividing by n_vocab
    # scales them into [0, 1) -- confirm against prepare_text.
    X = np.reshape(dataX, (n_patterns, SEQ_LEN, 1))
    X = X / float(n_vocab)

    # Pin the one-hot width to the vocabulary size. Without num_classes,
    # to_categorical infers it as max(dataY) + 1. That is narrower than
    # n_vocab whenever the highest-index characters never appear as a target,
    # which would desynchronize the model output from `chars` / n_vocab.
    y = to_categorical(dataY, num_classes=n_vocab)

    model = Model(X, y)
    model.compile_model()
    model.fit_model(X, y, BATCH_SIZE, N_EPOCHS)

    # peek_random_seed returns an index into dataX; that window seeds generation.
    pattern = dataX[peek_random_seed(dataX)]

    prediction = generate_text(model, pattern, chars, n_vocab, PREDICTION_LEN)

    print("\nPREDICTION: ", prediction)


if __name__ == '__main__':
    main()
|