IIS_2023_1/romanova_adelina_lab_7/generate.py

46 lines
1.5 KiB
Python
Raw Normal View History

2023-12-25 01:19:51 +04:00
import torch
from model import CharModel
import numpy as np
def load_model(path="single-char.pth"):
    """Load a trained CharModel and its vocabulary from a checkpoint.

    The checkpoint must contain a ``(state_dict, char_to_int)`` tuple.

    Args:
        path: Path to the checkpoint file.

    Returns:
        ``(model, char_to_int)``, with the model already switched to eval mode.
    """
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a
    # CPU-only machine; the freshly constructed CharModel lives on the CPU.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict, char_to_int = torch.load(path, map_location="cpu")
    model = CharModel()
    model.load_state_dict(state_dict)
    model.eval()
    return model, char_to_int


def random_prompt(text, seq_length, rng=None):
    """Return a random contiguous slice of ``text`` with length ``seq_length``.

    Args:
        text: Source text to sample from.
        seq_length: Length of the returned prompt.
        rng: Object providing a NumPy-style ``randint(low, high)``; defaults
            to the global ``np.random`` state.

    Raises:
        ValueError: If ``text`` is shorter than ``seq_length``.
    """
    if len(text) < seq_length:
        raise ValueError(
            f"text has {len(text)} characters, need at least {seq_length}"
        )
    rng = np.random if rng is None else rng
    # randint's upper bound is exclusive; +1 makes the last full window reachable.
    start = rng.randint(0, len(text) - seq_length + 1)
    return text[start:start + seq_length]


def generate_chars(model, pattern, n_vocab, int_to_char, n_chars=1000):
    """Greedily generate ``n_chars`` characters, yielding them one at a time.

    Each step divides the current window of character indices by ``n_vocab``,
    feeds it to ``model`` as a ``(1, window, 1)`` float32 tensor, takes the
    argmax of the output as the next character index, and slides the window
    forward by one position.

    Args:
        model: Callable mapping the input tensor to per-character scores.
        pattern: Initial window of character indices; it is not modified.
        n_vocab: Vocabulary size, used to normalize the inputs.
        int_to_char: Mapping from character index to character.
        n_chars: Number of characters to generate.

    Yields:
        The generated characters, in order.
    """
    window = list(pattern)  # copy so the caller's list is left untouched
    for _ in range(n_chars):
        x = np.reshape(window, (1, len(window), 1)) / float(n_vocab)
        x = torch.tensor(x, dtype=torch.float32)
        # Scope no_grad to the forward pass only: holding the context open
        # across `yield` would leave grad mode disabled in the caller.
        with torch.no_grad():
            prediction = model(x)
        index = int(prediction.argmax())
        yield int_to_char[index]
        window = window[1:] + [index]


def main(checkpoint="single-char.pth", filename="wonderland.txt",
         seq_length=100, n_chars=1000):
    """Print a random prompt from ``filename`` and the model's continuation."""
    model, char_to_int = load_model(checkpoint)
    n_vocab = len(char_to_int)
    int_to_char = {i: c for c, i in char_to_int.items()}

    # Randomly pick a prompt from the lower-cased source text.
    with open(filename, "r", encoding="utf-8") as f:
        raw_text = f.read().lower()
    prompt = random_prompt(raw_text, seq_length)
    # NOTE(review): raises KeyError if the text contains a character missing
    # from the checkpoint's vocabulary -- assumes both use the same preprocessing.
    pattern = [char_to_int[c] for c in prompt]

    print(f'Prompt:\n{prompt}')
    print("==="*15, "Сгенерированный результат", "==="*15, sep=" ")
    for ch in generate_chars(model, pattern, n_vocab, int_to_char, n_chars):
        print(ch, end="")
    print()
    print("==="*30)
    print("Done.")


if __name__ == "__main__":
    main()