39 lines
1.4 KiB
Python
39 lines
1.4 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from config import BATCH_SIZE, SEQ_LEN, PREDICTION_LEN
|
|
|
|
|
|
def get_batch(sequence, batch_size=None, seq_len=None):
    """Sample a random mini-batch of input/target windows from ``sequence``.

    Each sample is a contiguous window of ``seq_len`` items taken at a
    uniformly random start position. The input is the window without its
    last item and the target is the window without its first item, i.e. the
    target is the input shifted one step ahead.

    Args:
        sequence: Indexable, sliceable collection of integer token ids
            (anything ``torch.LongTensor`` accepts once sliced).
        batch_size: Number of windows to sample. Defaults to the module-level
            ``BATCH_SIZE`` when ``None``.
        seq_len: Length of each sampled window. Defaults to the module-level
            ``SEQ_LEN`` when ``None``.

    Returns:
        A ``(inputs, targets)`` tuple of ``LongTensor`` objects, each of shape
        ``(batch_size, seq_len - 1, 1)``.

    Raises:
        ValueError: If ``sequence`` is shorter than ``seq_len``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if seq_len is None:
        seq_len = SEQ_LEN

    # Number of valid window starts. ``np.random.randint`` excludes its upper
    # bound, so the "+ 1" is required for the final window (the one that ends
    # on the last element of ``sequence``) to be reachable.
    n_starts = len(sequence) - seq_len + 1
    if n_starts < 1:
        raise ValueError(
            f"sequence of length {len(sequence)} is shorter than "
            f"seq_len={seq_len}"
        )

    trains = []
    targets = []
    for _ in range(batch_size):
        start = np.random.randint(0, n_starts)
        chunk = sequence[start:start + seq_len]
        # Trailing singleton dimension: one token id per time step.
        trains.append(torch.LongTensor(chunk[:-1]).view(-1, 1))
        targets.append(torch.LongTensor(chunk[1:]).view(-1, 1))

    return torch.stack(trains, dim=0), torch.stack(targets, dim=0)
|
|
|
|
|
|
def evaluate(model, char_to_idx, idx_to_char, device, start_text=' ',
             prediction_len=PREDICTION_LEN, temp=0.3):
    """Generate text from ``model`` by sampling one character at a time.

    The model's hidden state is first primed on ``start_text``; then
    ``prediction_len`` characters are drawn from the temperature-scaled
    softmax of the model output, each sampled character being fed back as
    the next input.

    Args:
        model: Callable as ``model(inp, hidden) -> (output, hidden)`` and
            providing ``init_hidden()``. ``inp`` is a ``LongTensor`` of shape
            ``(steps, 1, 1)``. For a single-step input the flattened
            ``output`` must have ``len(char_to_idx)`` entries.
        char_to_idx: Mapping from character to integer id.
        idx_to_char: Inverse lookup from integer id to character.
        device: Device the input tensors are moved to.
        start_text: Non-empty seed string; every character must be a key of
            ``char_to_idx``.
        prediction_len: Number of characters to generate.
        temp: Softmax temperature; lower values make sampling greedier.

    Returns:
        ``start_text`` followed by the ``prediction_len`` generated characters.

    Raises:
        ValueError: If ``start_text`` is empty.
        KeyError: If ``start_text`` contains a character missing from
            ``char_to_idx`` (when it is a dict).
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")

    idx_input = [char_to_idx[char] for char in start_text]
    seed = torch.LongTensor(idx_input).view(-1, 1, 1).to(device)
    hidden = model.init_hidden()
    pieces = [start_text]

    # Pure inference: skip building an autograd graph on every step.
    # NOTE(review): train/eval mode of the model is left to the caller.
    with torch.no_grad():
        # Prime the hidden state on everything except the last seed
        # character. That last character is the first generation input, so
        # priming on it as well would make the model see it twice.
        if seed.size(0) > 1:
            _, hidden = model(seed[:-1], hidden)
        inp = seed[-1].view(-1, 1, 1)

        for _ in range(prediction_len):
            output, hidden = model(inp, hidden)
            logits = output.detach().cpu().view(-1)
            p_next = F.softmax(logits / temp, dim=-1).numpy()
            top_index = int(np.random.choice(len(char_to_idx), p=p_next))
            pieces.append(idx_to_char[top_index])
            # Feed the sampled character back in as the next input.
            inp = torch.LongTensor([top_index]).view(-1, 1, 1).to(device)

    return ''.join(pieces)