# IIS_2023_1/lipatov_ilya_lab_7/lab7.py
#
# 76 lines, 2.1 KiB, Python
# (repository web view: Raw | Permalink | Normal View | History)
#
# Last commit: 2023-12-01 22:38:00 +04:00
from tensorflow.python.keras.utils.np_utils import to_categorical
from keras.layers import Dense, Dropout, LSTM
from keras.models import Sequential
import numpy as np
# --- Training schedule -------------------------------------------------------
N_EPOCHS: int = 64      # number of epochs passed to Model.fit_model
BATCH_SIZE: int = 128   # batch size passed to Model.fit_model

# --- Network architecture ----------------------------------------------------
N_UNITS: int = 256          # units in each LSTM layer
DROPOUT_RATE: float = 0.2   # rate of the Dropout layer that follows each LSTM

# --- Sequence lengths (in characters) ----------------------------------------
SEQ_LEN: int = 100          # length of every training input window
PREDICTION_LEN: int = 100   # number of characters generated after training
def prepare_text(text, ids, seq_length):
    """Cut *text* into fixed-length id windows paired with next-character ids.

    Args:
        text: Source sequence of characters to slice.
        ids: Mapping from every character in *text* to its integer id.
        seq_length: Number of characters in each input window.

    Returns:
        A ``(dataX, dataY)`` pair of lists. ``dataX[k]`` holds the ids of
        ``text[k:k + seq_length]`` and ``dataY[k]`` is the id of the character
        immediately after that window. Both are empty when
        ``len(text) <= seq_length``.
    """
    def encode(chunk):
        return [ids[ch] for ch in chunk]

    # Each sample is (window ids, id of the following character).
    samples = [
        (encode(text[pos:pos + seq_length]), ids[text[pos + seq_length]])
        for pos in range(len(text) - seq_length)
    ]
    inputs = [window for window, _ in samples]
    targets = [target for _, target in samples]
    return inputs, targets
def generate_text(model, pattern, chars, n_vocab, prediction_len):
    """Greedily generate *prediction_len* characters from a seed window.

    At each step the current window of ids is reshaped to ``(1, len, 1)``,
    divided by ``n_vocab`` and passed to ``model.predict``. The id with the
    highest score (``argmax``) is emitted. The window then slides forward by
    one position: the oldest id is dropped and the new one is appended.

    Args:
        model: Object exposing ``predict(x, verbose=0)`` whose output's
            ``argmax`` is a character id.
        pattern: Seed sequence of integer character ids. It is NOT modified.
        chars: Mapping from character id to character.
        n_vocab: Divisor used to scale ids before prediction; it should match
            the scaling applied to the training inputs.
        prediction_len: Number of characters to generate.

    Returns:
        The generated string, ``prediction_len`` characters long.
    """
    # Work on a private copy so the caller's seed list is left untouched.
    window = list(pattern)
    generated = []
    for _ in range(prediction_len):
        x = np.reshape(window, (1, len(window), 1)) / float(n_vocab)
        prediction = model.predict(x, verbose=0)
        index = int(np.argmax(prediction))
        generated.append(chars[index])
        # Slide the window: drop the oldest id, append the predicted one.
        window = window[1:] + [index]
    # Join once instead of repeated (quadratic) string concatenation.
    return "".join(generated)
class Model(Sequential):
    """Stacked two-layer LSTM classifier built from the training data's shapes.

    Layers: LSTM (returning sequences) -> Dropout -> LSTM -> Dropout ->
    Dense with softmax, one output unit per target class.
    """

    def __init__(self, x, y, n_units=N_UNITS, dropout_rate=DROPOUT_RATE):
        """Build the layer stack.

        Args:
            x: Training inputs; only ``x.shape[1]`` and ``x.shape[2]`` are
                read, as the ``input_shape`` of the first LSTM.
            y: Training targets; only ``y.shape[1]`` is read, to size the
                output layer.
            n_units: Units in each LSTM layer (defaults to ``N_UNITS``).
            dropout_rate: Rate of the Dropout layer after each LSTM
                (defaults to ``DROPOUT_RATE``).
        """
        super().__init__()
        self.add(LSTM(n_units, input_shape=(x.shape[1], x.shape[2]), return_sequences=True))
        self.add(Dropout(dropout_rate))
        self.add(LSTM(n_units))
        self.add(Dropout(dropout_rate))
        self.add(Dense(y.shape[1], activation='softmax'))

    def compile_model(self, optimizer='adam'):
        """Compile with categorical cross-entropy (expects one-hot targets).

        Args:
            optimizer: Optimizer name or instance passed to ``compile``
                (defaults to ``'adam'``).
        """
        self.compile(loss='categorical_crossentropy', optimizer=optimizer)

    def fit_model(self, x, y, batch_size, epochs):
        """Train on ``(x, y)`` and return the ``History`` that ``fit`` produces.

        Args:
            x: Training inputs.
            y: Training targets.
            batch_size: Samples per gradient update.
            epochs: Number of passes over the training data.
        """
        # Keyword arguments: do not rely on the positional order of fit().
        return self.fit(x, y, batch_size=batch_size, epochs=epochs)
# Load the corpus (closing the file afterwards) and lower-case it.
with open('warandpeace.txt', 'r', encoding='UTF-8') as corpus_file:
    text = corpus_file.read().lower()

# Character vocabulary plus lookup tables in both directions.
vocab = sorted(set(text))
n_vocab = len(vocab)
ids = {c: i for i, c in enumerate(vocab)}
chars = {i: c for i, c in enumerate(vocab)}

# Build (window, next-character) training pairs.
dataX, dataY = prepare_text(text=text, ids=ids, seq_length=SEQ_LEN)
n_patterns = len(dataX)

# Shape inputs as (samples, SEQ_LEN, 1) and scale ids by the vocabulary size.
# NOTE(review): for a large corpus this materialises n_patterns * SEQ_LEN
# values in memory; consider a generator or tf.data pipeline if RAM is tight.
X = np.reshape(dataX, (n_patterns, SEQ_LEN, 1))
X = X / float(n_vocab)

# Fix num_classes to the vocabulary size, so the output layer has one unit per
# character even if the highest id never appears as a target.
y = to_categorical(dataY, num_classes=n_vocab)

model = Model(X, y)
model.compile_model()
model.fit_model(X, y, BATCH_SIZE, N_EPOCHS)

# Pick a random seed window. randint's upper bound is exclusive, so pass
# len(dataX) to make every window (including the last) eligible. Copy the
# window so that generation cannot alter dataX.
pattern = list(dataX[np.random.randint(0, len(dataX))])
prediction = generate_text(model, pattern, chars, n_vocab, PREDICTION_LEN)
print("\nPREDICTION: ", prediction)