.. | ||
generate.py | ||
generated_text.png | ||
model.py | ||
README.md | ||
single-char.pth | ||
train_process.png | ||
train.py | ||
wonderland.txt |
Лабораторная работа №7. Вариант 21
Тема
Рекуррентная нейронная сеть и задача генерации текста
Задание
-
Выбрать художественный текст и обучить на нем рекуррентную нейронную сеть для решения задачи генерации.
-
Подобрать архитектуру и параметры так, чтобы приблизиться к максимально осмысленному результату.
Используемые ресурсы
1. Художественный текст на английском языке wonderland.txt
2. Python-скрипты: generate.py
, model.py
, train.py
.
Описание работы
Подготовка данных:
В файле train.py
реализована функция get_data
, которая загружает художественный текст, приводит его к нижнему регистру, и создает сопоставление символов числовым значениям.
Текст разбивается на последовательности фиксированной длины seq_length
, и каждая последовательность связывается с символом, следующим за ней.
Данные приводятся к тензорам PyTorch и нормализуются для обучения модели.
Архитектура модели:
В файле model.py
определен класс CharModel
, наследуемый от nn.Module
и представляющий собой рекуррентную нейронную сеть.
Архитектура модели включает в себя один слой LSTM с размером скрытого состояния 256, слой dropout для регуляризации и линейный слой для вывода результатов.
Обучение модели:
В файле train.py
реализован скрипт для обучения модели. Выбрана оптимизация Adam, функция потерь - CrossEntropyLoss
.
Обучение происходит на GPU, если он доступен. Обучение проводится в течение нескольких эпох, с валидацией на каждой эпохе. Сохраняется лучшая модель.
Процесс обучения модели:
Генерация текста:
В файле generate.py
модель загружается из сохраненного состояния. Генерируется случайный промпт из исходного текста, и модель используется для предсказания следующего символа в цикле.
Вывод:
В сгенерированном тексте можно найти осмысленные участки, поэтому можно сделать вывод, что модель действительно хорошо обучилась.