IIS_2023_1/almukhammetov_bulat_lab_7/lab7.py
2023-12-02 11:45:17 +04:00

90 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Выбрать художественный текст (четные варианты русскоязычный, нечетные англоязычный) и обучить на нем
# рекуррентную нейронную сеть для решения задачи генерации. Подобрать архитектуру и параметры так, чтобы приблизиться
# к максимально осмысленному результату. Далее разбиться на пары четный-нечетный вариант, обменяться разработанными
# сетями и проверить, как архитектура товарища справляется с вашим текстом. В завершении подобрать компромиссную
# архитектуру, справляющуюся достаточно хорошо с обоими видами текстов.
import numpy as np
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
def load_text(file_path):
with open(file_path, encoding='utf-8') as file:
return file.read()
def create_tokenizer(text):
tokenizer = Tokenizer()
tokenizer.fit_on_texts([text])
return tokenizer
def generate_input_sequences(text, tokenizer):
input_sequences = []
for line in text.split('\n'):
token_list = tokenizer.texts_to_sequences([line])[0]
for i in range(1, len(token_list)):
n_gram_sequence = token_list[:i + 1]
input_sequences.append(n_gram_sequence)
max_sequence_length = max([len(x) for x in input_sequences])
input_sequences = pad_sequences(input_sequences, maxlen=max_sequence_length, padding='pre')
predictors, labels = input_sequences[:, :-1], input_sequences[:, -1]
return predictors, labels, max_sequence_length
def create_model(total_words, max_sequence_length):
model = Sequential()
model.add(Embedding(total_words, 100, input_length=max_sequence_length - 1))
model.add(LSTM(150))
model.add(Dense(total_words, activation='softmax'))
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
def train_model(model, predictors, labels, epochs):
model.fit(predictors, labels, epochs=epochs, verbose=1)
def generate_text(seed_text, next_words, model, tokenizer, max_sequence_length):
for _ in range(next_words):
token_list = tokenizer.texts_to_sequences([seed_text])[0]
token_list = pad_sequences([token_list], maxlen=max_sequence_length - 1, padding='pre')
predicted = np.argmax(model.predict(token_list), axis=-1)
output_word = ""
for word, index in tokenizer.word_index.items():
if index == predicted:
output_word = word
break
seed_text += " " + output_word
return seed_text
# Загрузка текста
#file_path = 'russian_text.txt'
file_path = 'english_text.txt'
text = load_text(file_path)
# Создание токенизатора
tokenizer = create_tokenizer(text)
total_words = len(tokenizer.word_index) + 1
# Генерация входных последовательностей
predictors, labels, max_sequence_length = generate_input_sequences(text, tokenizer)
# Создание модели
model = create_model(total_words, max_sequence_length)
# Тренировка модели
train_model(model, predictors, labels, epochs=150)
# Генерация текста
seed_text = "Old man"
next_words = 50
generated_text = generate_text(seed_text, next_words, model, tokenizer, max_sequence_length)
print(generated_text)