IIS_2023_1/kondrashin_mikhail_lab_7/main.py

37 lines
951 B
Python
Raw Normal View History

2023-11-27 01:52:58 +04:00
from tensorflow.python.keras.utils.np_utils import to_categorical
from utils import SEQ_LEN, N_EPOCHS, BATCH_SIZE, PREDICTION_LEN
from data import prepare_text, generate_text, read_text, peek_random_seed
from model import Model
import numpy as np
def build_vocab(text):
    """Build the character vocabulary and lookup tables for *text*.

    Returns a ``(vocab, ids, chars)`` tuple:
    ``vocab`` is the sorted list of distinct characters in *text*,
    ``ids`` maps each character to its index in ``vocab``, and
    ``chars`` is the inverse mapping (index -> character).
    """
    vocab = sorted(set(text))
    ids = {c: i for i, c in enumerate(vocab)}
    chars = {i: c for i, c in enumerate(vocab)}
    return vocab, ids, chars


def main():
    """Train the character-level model on ``en.txt`` and print a generated sample.

    Steps: read the corpus, build the vocabulary, cut the text into
    SEQ_LEN-long input windows with ``prepare_text``, train ``Model``, then
    generate PREDICTION_LEN characters starting from a randomly chosen window.
    """
    text = read_text('en.txt')
    vocab, ids, chars = build_vocab(text)
    n_vocab = len(vocab)

    dataX, dataY = prepare_text(text=text, ids=ids, seq_length=SEQ_LEN)
    n_patterns = len(dataX)
    if n_patterns == 0:
        # Fail with a readable message rather than an obscure shape error
        # later in reshape/fit (presumably happens when the corpus is not
        # longer than SEQ_LEN — depends on prepare_text).
        raise SystemExit(
            f"prepare_text produced no training patterns "
            f"(corpus length {len(text)}, SEQ_LEN={SEQ_LEN})"
        )

    # Shape the windows as (samples, time steps, 1 feature) and scale the
    # integer ids by the vocabulary size (values fall in [0, 1) assuming
    # prepare_text encodes characters through ``ids``).
    X = np.reshape(dataX, (n_patterns, SEQ_LEN, 1))
    X = X / float(n_vocab)

    # Pin the one-hot width to the vocabulary size: by default to_categorical
    # uses max(dataY) + 1 classes, which is narrower than n_vocab whenever the
    # highest-id character never occurs as a target, and generate_text below
    # is told the output space has n_vocab entries.
    y = to_categorical(dataY, num_classes=n_vocab)

    model = Model(X, y)
    model.compile_model()
    model.fit_model(X, y, BATCH_SIZE, N_EPOCHS)

    pattern = dataX[peek_random_seed(dataX)]
    prediction = generate_text(model, pattern, chars, n_vocab, PREDICTION_LEN)
    print("\nPREDICTION: ", prediction)


if __name__ == '__main__':
    main()