53 lines
3.3 KiB
Markdown
53 lines
3.3 KiB
Markdown
|
# Лабораторная работа №7. Вариант 21
|
|||
|
|
|||
|
## Тема
|
|||
|
|
|||
|
Рекуррентная нейронная сеть и задача генерации текста
|
|||
|
|
|||
|
## Задание
|
|||
|
|
|||
|
- Выбрать художественный текст и обучить на нем рекуррентную нейронную сеть для решения задачи генерации.
|
|||
|
|
|||
|
- Подобрать архитектуру и параметры так, чтобы приблизиться к максимально осмысленному результату.
|
|||
|
|
|||
|
## Используемые ресурсы
|
|||
|
|
|||
|
1. Художественный текст на английском языке ```wonderland.txt```
|
|||
|
|
|||
|
2. Python-скрипты: ```generate.py```, ```model.py```, ```train.py```.
|
|||
|
|
|||
|
## Описание работы
|
|||
|
|
|||
|
### Подготовка данных:
|
|||
|
В файле ```train.py``` реализована функция ```get_data```, которая загружает художественный текст, приводит его к нижнему регистру, и создает сопоставление символов числовым значениям.
|
|||
|
|
|||
|
Текст разбивается на последовательности фиксированной длины ```seq_length```, и каждая последовательность связывается с символом, следующим за ней.
|
|||
|
|
|||
|
Данные приводятся к тензорам PyTorch и нормализуются для обучения модели.
|
|||
|
|
|||
|
### Архитектура модели:
|
|||
|
|
|||
|
В файле ```model.py``` определен класс ```CharModel```, наследуемый от ```nn.Module``` и представляющий собой рекуррентную нейронную сеть.
|
|||
|
|
|||
|
Архитектура модели включает в себя один слой LSTM с размером скрытого состояния 256, слой dropout для регуляризации и линейный слой для вывода результатов.
|
|||
|
|
|||
|
### Обучение модели:
|
|||
|
|
|||
|
В файле ```train.py``` реализован скрипт для обучения модели. Выбрана оптимизация Adam, функция потерь - ```CrossEntropyLoss```.
|
|||
|
|
|||
|
Обучение происходит на GPU, если он доступен. Обучение проводится в течение нескольких эпох, с валидацией на каждой эпохе. Сохраняется лучшая модель.
|
|||
|
|
|||
|
Процесс обучения модели:
|
|||
|
|
|||
|
![](train_process.png "")
|
|||
|
|
|||
|
### Генерация текста:
|
|||
|
|
|||
|
В файле ```generate.py``` модель загружается из сохраненного состояния. Генерируется случайный промпт из исходного текста, и модель используется для предсказания следующего символа в цикле.
|
|||
|
|
|||
|
## Вывод:
|
|||
|
|
|||
|
![](generated_text.png "")
|
|||
|
|
|||
|
В сгенерированном тексте можно найти осмысленные участки, поэтому можно сделать вывод, что модель действительно хорошо обучилась.
|