# Source: IIS_2023_1/malkova_anastasia_lab_7/train.py
# Last modified: 2023-11-17 01:50:49 +04:00
import torch
import torch.nn as nn
import numpy as np
from text_rnn import TextRNN
from config import BATCH_SIZE, N_EPOCHS, LOSS_AVG_MAX, HIDDEN_SIZE, EMBEDDING_SIZE, N_LAYERS
from generation import get_batch, evaluate
def create_parameters(idx_to_char, *, lr=1e-2, patience=5, factor=0.5):
    """Build the model and training objects for the character-level RNN.

    Args:
        idx_to_char: vocabulary mapping; only its length is used, as the
            model's ``input_size``.
        lr: initial Adam learning rate (default ``1e-2``, the value that was
            previously hard-coded).
        patience: ``ReduceLROnPlateau`` patience, counted in calls to
            ``scheduler.step`` (default ``5``).
        factor: multiplier applied to the learning rate when the monitored
            metric plateaus (default ``0.5``).

    Returns:
        Tuple ``(criterion, scheduler, n_epochs, loss_avg, device, model,
        optimizer)`` where ``criterion`` is ``nn.CrossEntropyLoss``,
        ``n_epochs`` is ``N_EPOCHS`` from config and ``loss_avg`` is a new
        empty list for collecting per-step loss values.
    """
    # Use the GPU when PyTorch can see one, otherwise fall back to CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = TextRNN(
        input_size=len(idx_to_char),
        hidden_size=HIDDEN_SIZE,
        embedding_size=EMBEDDING_SIZE,
        n_layers=N_LAYERS,
        device=device,
    )
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
    # Scale the LR by ``factor`` once the metric passed to scheduler.step()
    # stops decreasing (ReduceLROnPlateau's default mode is 'min').
    # NOTE(review): newer PyTorch releases deprecate ``verbose`` on LR
    # schedulers — confirm against the installed version; kept for parity.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        patience=patience,
        verbose=True,
        factor=factor,
    )

    n_epochs = N_EPOCHS
    loss_avg = []
    return criterion, scheduler, n_epochs, loss_avg, device, model, optimizer
def check_loss(loss_avg, scheduler, model, char_to_idx, idx_to_char, device):
    """Report progress once enough loss values have been collected.

    When ``loss_avg`` holds at least ``LOSS_AVG_MAX`` entries, print their
    mean, pass that mean to ``scheduler.step``, switch ``model`` to eval mode
    and print whatever ``evaluate`` returns (presumably generated sample
    text — confirm in ``generation``).

    Returns:
        ``loss_avg`` itself if the threshold was not reached, otherwise a new
        empty list. Callers must rebind their variable to the result.
    """
    if len(loss_avg) < LOSS_AVG_MAX:
        # Not enough samples yet — keep accumulating.
        return loss_avg

    average = np.mean(loss_avg)
    print(f'Loss: {average}')
    scheduler.step(average)

    # The model is left in eval mode here; the training loop is expected to
    # call model.train() again before its next update.
    model.eval()
    sample = evaluate(model, char_to_idx, idx_to_char, device=device)
    print(sample)
    return []
def training(n_epochs, model, sequence, device, criterion, optimizer, loss_avg, scheduler, char_to_idx, idx_to_char):
    """Train ``model`` for ``n_epochs`` optimisation steps.

    Each loop iteration (called an "epoch" here) draws ONE batch with
    ``get_batch(sequence)``, runs a forward pass from a fresh hidden state and
    takes a single optimizer step. The step's loss is appended to
    ``loss_avg``; ``check_loss`` then may print the running mean, step the
    LR scheduler, print a sample, and hand back a new empty buffer.

    Args:
        n_epochs: number of training iterations (one batch each).
        model: recurrent model exposing ``init_hidden(batch_size)``.
        sequence: training data, passed unchanged to ``get_batch``.
        device: device the batch tensors are moved to.
        criterion: loss function (e.g. the ``nn.CrossEntropyLoss`` built by
            ``create_parameters``).
        optimizer: optimizer over ``model.parameters()``.
        loss_avg: running list of recent per-step loss values.
        scheduler: LR scheduler, stepped inside ``check_loss``.
        char_to_idx: vocabulary mapping forwarded to ``check_loss``.
        idx_to_char: vocabulary mapping forwarded to ``check_loss``.

    Returns:
        The running-loss list as it stands after the last iteration. Because
        ``check_loss`` replaces the list when it resets it, the caller's
        original ``loss_avg`` object can be stale — use this return value.
    """
    for _ in range(n_epochs):
        model.train()

        train, target = get_batch(sequence)
        # Swap the first two axes; presumably get_batch yields
        # (batch, seq, 1) and the model expects (seq, batch, 1) — TODO confirm.
        train = train.permute(1, 0, 2).to(device)
        target = target.permute(1, 0, 2).to(device)
        hidden = model.init_hidden(BATCH_SIZE)

        # Clear gradients before the backward pass so nothing accumulated
        # before this step (e.g. prior to this call) leaks into the update.
        optimizer.zero_grad()
        output, hidden = model(train, hidden)
        # Presumably output is (seq, batch, vocab): the permute gives
        # (batch, vocab, seq) logits and (batch, seq) targets, the
        # (N, C, d) / (N, d) layout CrossEntropyLoss expects — TODO confirm.
        loss = criterion(output.permute(1, 2, 0), target.squeeze(-1).permute(1, 0))

        loss.backward()
        optimizer.step()

        loss_avg.append(loss.item())
        loss_avg = check_loss(loss_avg, scheduler, model, char_to_idx, idx_to_char, device)

    return loss_avg