IIS_2023_1/madyshev_egor_lab_7/main.py

66 lines
2.5 KiB
Python
Raw Normal View History

2023-11-02 19:12:30 +04:00
import numpy as np
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.models import Sequential
from keras.layers import LSTM, Dense, Embedding
# Чтение текста из файла
with open('mumu.txt', 'r', encoding='utf-8') as file:
text = file.read()
# Параметры модели
seq_length = 50 # Длина входных последовательностей
num_epochs = 50
gen_length = 200 # Длина генерируемого текста
seed_text = "Начнем с этого" # Начальная фраза для генерации
# Создание экземпляра Tokenizer и обучение на тексте
tokenizer = Tokenizer()
tokenizer.fit_on_texts([text])
vocab_size = len(tokenizer.word_index) + 1 # Размер словаря
# Преобразование текста в последовательности чисел
sequences = tokenizer.texts_to_sequences([text])[0]
# Создание входных и выходных последовательностей
X_data = []
y_data = []
for i in range(seq_length, len(sequences)):
sequence = sequences[i - seq_length:i]
target = sequences[i]
X_data.append(sequence)
y_data.append(target)
X = np.array(X_data)
y = np.array(y_data)
# Создание модели RNN
model = Sequential()
model.add(Embedding(input_dim=vocab_size, output_dim=128, input_length=seq_length))
model.add(LSTM(256, return_sequences=True))
model.add(LSTM(256))
model.add(Dense(vocab_size, activation='softmax'))
# Компиляция модели
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# Обучение модели
model.fit(X, y, epochs=num_epochs, batch_size=64, verbose=1)
# Функция для генерации текста
def generate_text(seed_text, gen_length):
generated_text = seed_text
for _ in range(gen_length):
sequence = tokenizer.texts_to_sequences([seed_text])[0]
sequence = pad_sequences([sequence], maxlen=seq_length)
prediction = model.predict(sequence, verbose=0)
predicted_index = np.argmax(prediction)
predicted_word = [word for word, index in tokenizer.word_index.items() if index == predicted_index][0]
generated_text += " " + predicted_word
seed_text += " " + predicted_word
return generated_text
# Генерация текста
generated_text = generate_text(seed_text, gen_length)
print(generated_text)