IIS_2023_1/istyukov_timofey_lab_7/lab7.py
2024-01-13 22:21:26 +04:00

127 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Выбрать художественный текст на языке по варианту и обучить на нём рекуррентную нейронную сеть для решения задачи
генерации. Подобрать архитектуру и параметры так, чтобы приблизиться к максимально осмысленному результату.
Далее разбиться на пары чётный-нечётный вариант, обменяться разработанными сетями и проверить, как архитектура товарища
справляется с вашим текстом.
В завершении подобрать компромиссную архитектуру, справляющуюся достаточно хорошо с обоими видами текстов.
"""
# 12 вариант
# Вариант языка текста: русский
# Художественный текст: Книга "Ф.М. Достоевский — Преступление и наказание"
import os
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM
from keras.utils import to_categorical
FILE_NAME = 'belye-nochi.txt'
# Открытие файла
df_text = (open(FILE_NAME, encoding='utf-8').read())
"""""""""""""""""""""""""""""""""
ПРЕДВАРИТЕЛЬНАЯ ОБРАБОТКА ДАННЫХ
"""""""""""""""""""""""""""""""""
# Перевод всех символов в нижний регистр для упрощения обучения
df_text = df_text.lower()
# Формирование набора символов на основе текста
characters = sorted(list(set(df_text)))
print("\033[92m\n---> Итого символов: \033[00m", len(characters))
# Сопоставления символов к номеру
char_to_n = {char: n for n, char in enumerate(characters)}
# Массивы
X = [] # обучающий
Y = [] # целевой
# Длина исходного текста
length = len(df_text)
# Длина последовательности символов для предсказания конкретного символа
seq_length = 5
# Перебора полного текста
for i in range(0, length - seq_length, 1):
sequence = df_text[i:i + seq_length]
label = df_text[i + seq_length]
X.append([char_to_n[char] for char in sequence])
Y.append(char_to_n[label])
# Масштабирование целых чисел в диапазон от 0 до 1 для облегчения изучения шаблонов сетью
X_modified = np.reshape(X, (len(X), seq_length, 1))
X_modified = X_modified / float(len(characters))
Y_modified = to_categorical(Y)
"""""""""""""""""""""
ПОСТРОЕНИЕ МОДЕЛИ
"""""""""""""""""""""
# Инициализация модели
model = Sequential()
# Пополнение модели атрибутами
model.add(LSTM(700, input_shape=(X_modified.shape[1], X_modified.shape[2]), return_sequences=True)) # первый слой на 700 единиц с входной формой
model.add(Dropout(0.2)) # кик нейронов с вероятностью 20%
model.add(LSTM(700, return_sequences=True)) # второй слой на 700 единиц, обрабатывающий те же последовательности
model.add(Dropout(0.2))
model.add(LSTM(700)) # третий слой на 700 единиц
model.add(Dropout(0.2))
model.add(Dense(Y_modified.shape[1], activation='softmax')) # сеть с плотным слоем для вывода символов
# Конфигурация модели с вычислением категориальных потерь кроссэнтропии
model.compile(loss='categorical_crossentropy', optimizer='adam')
# Обучение модели, если сохранённая модель в текущей папке отсутствует
if not os.path.exists('save_text_generator_deeper_model.h5'):
# Обучение модели на 50 эпохах и 100 обучающих примерах за один проход
model.fit(X_modified, Y_modified, epochs=50, batch_size=100)
# Сохранение обученной модели в файл в текущей папке
model.save_weights('save_text_generator_deeper_model.h5')
"""""""""""""""""""""""""""""""""""""""""""""
Генерация текста
"""""""""""""""""""""""""""""""""""""""""""""
# Загрузка обученной модели с текущей папки
model.load_weights('save_text_generator_deeper_model.h5')
# Сопоставления номеров обратно к символам
n_to_char = dict((i, c) for i, c in enumerate(characters))
# Выбор случайной точки старта в тексте для генерации
start = np.random.randint(0, len(X) - 1)
# Последовательность этой точки
pattern = X[start]
txtxt = "" # строка результата
# сохранение старта в результат
for value in pattern:
txtxt += n_to_char[value]
print("\033[92m\n---> Точка старта: \033[00m", txtxt)
# Генерация 200 символов
for i in range(200):
# Масштабирование последовательности символов
x = np.reshape(pattern, (1, len(pattern), 1))
x = x / float((len(characters)))
prediction = model.predict(x, verbose=0) # прогноз вероятностей к каждому символу
index = np.argmax(prediction) # выбор индекса лучшего по вероятности
txtxt += n_to_char[index] # запись символа с таким индексом в результат
# сохранение индекса символа в конечную результирующую последовательность
pattern.append(index)
pattern = pattern[1:len(pattern)]
print("\033[92m\n[----------> Результат <----------]\033[00m")
print(txtxt)