IIS_2023_1/mashkova_margarita_lab_7/main.py

81 lines
2.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense, Embedding
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
# filename = "russian_text.txt"
filename = "english_text.txt"
# Чтение текста из файла
with open(filename, "r", encoding="utf-8") as f:
text = f.read()
# Создание токенизатора
tokenizer = Tokenizer()
# Создает словарь вида (слово - индекс) на основе частоты использования слов
# чем меньше индекс, тем чаще встречается слово
# т.е. каждое слово принимает целочисленное значение
tokenizer.fit_on_texts([text])
# print("Словарь:")
# print(tokenizer.word_index)
# Преобразование текста в последовательность чисел
sequences = tokenizer.texts_to_sequences([text])[0]
vocab_size = len(tokenizer.word_index) + 1
# print("\nЗакодированный текст:")
# print(sequences)
# Длина входных последовательностей
seq_length = 5
# Создание входных и выходных последовательностей
X_data = []
y_data = []
for i in range(seq_length, len(sequences)):
sequence = sequences[i - seq_length:i]
target = sequences[i]
X_data.append(sequence)
y_data.append(target)
X = np.array(X_data)
y = np.array(y_data)
# Создание модели
model = Sequential()
model.add(Embedding(input_dim=vocab_size, output_dim=128, input_length=seq_length))
model.add(LSTM(256, return_sequences=True))
model.add(LSTM(256))
model.add(Dense(vocab_size, activation='softmax'))
# Компиляция модели
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# Обучение модели
model.fit(X, y, epochs=500, batch_size=64, verbose=1)
# Начальная фраза для генерации
seed_text = "an old woman"
# Длина генерируемого текста
gen_length = 100
# Функция для генерации текста
def generate_text(seed_text, gen_length):
generated_text = seed_text
for _ in range(gen_length):
sequence = tokenizer.texts_to_sequences([seed_text])[0]
sequence = pad_sequences([sequence], maxlen=seq_length)
prediction = model.predict(sequence, verbose=0)
predicted_index = np.argmax(prediction)
predicted_word = [word for word, index in tokenizer.word_index.items() if index == predicted_index][0]
generated_text += " " + predicted_word
seed_text += " " + predicted_word
return generated_text
# Генерация текста
generated_text = generate_text(seed_text, gen_length)
print(generated_text)