import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense

# Load the text file
file_path = "A.txt"

with open(file_path, "r", encoding="utf-8") as file:
    text = file.read()

# Data preprocessing: build character/index mappings over the vocabulary
chars = sorted(list(set(text)))
char_indices = {char: i for i, char in enumerate(chars)}
indices_char = {i: char for i, char in enumerate(chars)}

# Cut the text into overlapping sequences of maxlen characters,
# sliding the window forward by `step` characters each time
maxlen = 40
step = 3
sentences = []
next_chars = []

for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i : i + maxlen])
    next_chars.append(text[i + maxlen])

# One-hot encode the input sequences (x) and the target characters (y)
x = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.uint8)
y = np.zeros((len(sentences), len(chars)), dtype=np.uint8)

for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1

# Define the RNN model
model = Sequential()
model.add(LSTM(128, input_shape=(maxlen, len(chars))))
model.add(Dense(len(chars), activation="softmax"))

model.compile(loss="categorical_crossentropy", optimizer="adam")

# Train the model
model.fit(x, y, batch_size=128, epochs=20)

# Text generation: start from a random seed sequence taken from the text
start_index = np.random.randint(0, len(text) - maxlen - 1)
seed_text = text[start_index : start_index + maxlen]

generated_text = seed_text
for i in range(400):
    # One-hot encode the current seed sequence
    x_pred = np.zeros((1, maxlen, len(chars)))
    for t, char in enumerate(seed_text):
        x_pred[0, t, char_indices[char]] = 1

    # Predict the next character (greedy argmax over the softmax output)
    preds = model.predict(x_pred, verbose=0)[0]
    next_index = np.argmax(preds)
    next_char = indices_char[next_index]

    # Append the prediction and slide the seed window forward by one character
    generated_text += next_char
    seed_text = seed_text[1:] + next_char

print(generated_text)