IIS_2023_1/malkova_anastasia_lab_7/generation.py

39 lines
1.4 KiB
Python
Raw Normal View History

2023-11-17 01:50:49 +04:00
import torch
import torch.nn.functional as F
import numpy as np
from config import BATCH_SIZE, SEQ_LEN, PREDICTION_LEN
def get_batch(sequence, batch_size=None, seq_len=None):
    """Sample a random batch of (input, target) windows from ``sequence``.

    Each sample is a window of ``seq_len`` consecutive items that starts at a
    uniformly random offset. The input is the window without its last item;
    the target is the same window shifted one step to the right.

    Args:
        sequence: Indexable, sliceable sequence of integer token ids.
        batch_size: Number of windows to sample; ``None`` means
            ``config.BATCH_SIZE``.
        seq_len: Window length; ``None`` means ``config.SEQ_LEN``.

    Returns:
        ``(inputs, targets)``: two int64 tensors of shape
        ``(batch_size, seq_len - 1, 1)``.

    Raises:
        ValueError: If ``sequence`` is shorter than ``seq_len``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if seq_len is None:
        seq_len = SEQ_LEN
    if len(sequence) < seq_len:
        raise ValueError(
            f"sequence of length {len(sequence)} is shorter than seq_len={seq_len}"
        )

    # randint's upper bound is exclusive; the +1 makes the final window
    # (starting at len(sequence) - seq_len) reachable.
    max_start = len(sequence) - seq_len + 1
    inputs, targets = [], []
    for _ in range(batch_size):
        start = np.random.randint(0, max_start)
        chunk = sequence[start:start + seq_len]
        inputs.append(torch.LongTensor(chunk[:-1]).view(-1, 1))
        targets.append(torch.LongTensor(chunk[1:]).view(-1, 1))
    return torch.stack(inputs, dim=0), torch.stack(targets, dim=0)
def evaluate(model, char_to_idx, idx_to_char, device, start_text=' ', prediction_len=None, temp=0.3):
    """Generate text by sampling one character at a time from ``model``.

    The model's hidden state is primed with ``start_text``. The model is then
    called repeatedly: each next character is sampled from a
    temperature-scaled softmax over the model output and fed back as the
    next input.

    Args:
        model: Recurrent model exposing ``init_hidden()`` and called as
            ``model(inputs, hidden) -> (output, hidden)``. Inputs are passed as
            int64 tensors of shape ``(seq_len, 1, 1)``.
        char_to_idx: Mapping from character to vocabulary index.
        idx_to_char: Mapping from vocabulary index back to character.
        device: Device the input tensors are moved to.
        start_text: Non-empty priming text; every character must be a key of
            ``char_to_idx``.
        prediction_len: Number of characters to generate. ``None`` (the
            default) means ``config.PREDICTION_LEN``.
        temp: Softmax temperature; lower values make sampling greedier.

    Returns:
        ``start_text`` followed by the ``prediction_len`` sampled characters.

    Raises:
        ValueError: If ``start_text`` is empty.
        KeyError: If ``start_text`` contains a character not in ``char_to_idx``.
    """
    if prediction_len is None:
        # Resolved at call time so the default tracks config.PREDICTION_LEN.
        prediction_len = PREDICTION_LEN
    if not start_text:
        raise ValueError("start_text must contain at least one character")

    idx_input = [char_to_idx[char] for char in start_text]
    prime = torch.LongTensor(idx_input).view(-1, 1, 1).to(device)
    vocab_size = len(char_to_idx)
    generated = []

    # NOTE(review): the model is not switched to eval() here; if it uses
    # dropout, callers should do that before generating.
    hidden = model.init_hidden()
    # Inference only: no autograd graph needs to be kept across the loop.
    with torch.no_grad():
        # Prime the hidden state on every character except the last one. The
        # last character is the first input of the sampling loop; priming on
        # the whole text as well would present that character twice.
        if prime.size(0) > 1:
            _, hidden = model(prime[:-1], hidden)
        inp = prime[-1].view(-1, 1, 1)

        for _ in range(prediction_len):
            output, hidden = model(inp, hidden)
            logits = output.detach().cpu().view(-1)
            p_next = F.softmax(logits / temp, dim=-1).numpy()
            top_index = np.random.choice(vocab_size, p=p_next)
            inp = torch.LongTensor([top_index]).view(-1, 1, 1).to(device)
            generated.append(idx_to_char[top_index])

    return start_text + ''.join(generated)