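"""Training utilities for a character-level TextRNN: model/optimizer setup,
periodic loss reporting with sample generation, and the training loop."""
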
import torch
import torch.nn as nn
import numpy as np

from text_rnn import TextRNN
from config import BATCH_SIZE, N_EPOCHS, LOSS_AVG_MAX, HIDDEN_SIZE, EMBEDDING_SIZE, N_LAYERS
from generation import get_batch, evaluate


def create_parameters(idx_to_char):
    """Build the model and everything needed to train it."""
    # Train on the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = TextRNN(
        input_size=len(idx_to_char),
        hidden_size=HIDDEN_SIZE,
        embedding_size=EMBEDDING_SIZE,
        n_layers=N_LAYERS,
        device=device,
    )
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, amsgrad=True)
    # Halve the learning rate when the mean loss stops improving.
    # Note: `verbose` is deprecated in recent PyTorch releases.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        patience=5,
        verbose=True,
        factor=0.5,
    )

    n_epochs = N_EPOCHS
    loss_avg = []  # running window of per-step losses

    return criterion, scheduler, n_epochs, loss_avg, device, model, optimizer


def check_loss(loss_avg, scheduler, model, char_to_idx, idx_to_char, device):
    """Every LOSS_AVG_MAX steps: report the mean loss, let the scheduler
    react to it, print a sample of generated text and reset the window."""
    if len(loss_avg) >= LOSS_AVG_MAX:
        mean_loss = np.mean(loss_avg)
        print(f'Loss: {mean_loss}')
        scheduler.step(mean_loss)
        loss_avg = []
        model.eval()
        predicted_text = evaluate(model, char_to_idx, idx_to_char, device=device)
        print(predicted_text)
    return loss_avg


def training(n_epochs, model, sequence, device, criterion, optimizer,
             loss_avg, scheduler, char_to_idx, idx_to_char):
    """Sample one batch per epoch and take one optimization step on it."""
    for epoch in range(n_epochs):
        model.train()
        train, target = get_batch(sequence)
        # Swap to sequence-first layout: (seq_len, batch, 1).
        train = train.permute(1, 0, 2).to(device)
        target = target.permute(1, 0, 2).to(device)
        hidden = model.init_hidden(BATCH_SIZE)

        output, hidden = model(train, hidden)
        # CrossEntropyLoss wants logits as (batch, n_classes, seq_len)
        # and class indices as (batch, seq_len).
        loss = criterion(output.permute(1, 2, 0), target.squeeze(-1).permute(1, 0))

        # Clear stale gradients before backpropagating this step's loss.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_avg.append(loss.item())
        loss_avg = check_loss(loss_avg, scheduler, model, char_to_idx, idx_to_char, device)
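

# Minimal usage sketch, not part of the original module's entry point: the
# corpus path, the vocabulary construction and the integer encoding fed to
# get_batch() are assumptions made for illustration.
if __name__ == '__main__':
    with open('corpus.txt', encoding='utf-8') as f:  # hypothetical corpus file
        text = f.read()

    # Index every distinct character in both directions.
    chars = sorted(set(text))
    char_to_idx = {c: i for i, c in enumerate(chars)}
    idx_to_char = {i: c for i, c in enumerate(chars)}
    sequence = np.array([char_to_idx[c] for c in text])  # assumed get_batch() input

    criterion, scheduler, n_epochs, loss_avg, device, model, optimizer = \
        create_parameters(idx_to_char)
    training(n_epochs, model, sequence, device, criterion, optimizer,
             loss_avg, scheduler, char_to_idx, idx_to_char)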