46 lines
1.5 KiB
Python
46 lines
1.5 KiB
Python
import torch
|
|
from model import CharModel
|
|
import numpy as np
|
|
|
|
# Inputs and sizes used by the script entry point below.
CHECKPOINT_PATH = "single-char.pth"  # (state_dict, char_to_int) pair; presumably written by the training script
TEXT_PATH = "wonderland.txt"  # corpus the random prompt is sampled from
SEQ_LENGTH = 100  # prompt / context-window length, in characters
N_CHARS = 1000  # number of characters to generate


def load_model(path=CHECKPOINT_PATH):
    """Load a ``CharModel`` and its character vocabulary from a checkpoint.

    Args:
        path: checkpoint file holding a ``(state_dict, char_to_int)`` pair.

    Returns:
        ``(model, char_to_int)``, with the model already switched to eval mode.
    """
    # map_location="cpu": generation feeds CPU tensors, and without it a
    # checkpoint saved from a GPU run cannot be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict, char_to_int = torch.load(path, map_location="cpu")
    model = CharModel()
    model.load_state_dict(state_dict)
    model.eval()
    return model, char_to_int


def random_prompt(text, seq_length, rng=None):
    """Return a uniformly random ``seq_length``-character slice of ``text``.

    Args:
        text: source text to sample from.
        seq_length: length of the returned slice.
        rng: object with a numpy-style ``randint(low, high)`` such as
            ``np.random.RandomState``; defaults to the global ``np.random``.

    Raises:
        ValueError: if ``text`` is shorter than ``seq_length``.
    """
    if len(text) < seq_length:
        raise ValueError(
            f"text has {len(text)} characters, need at least {seq_length}"
        )
    rng = np.random if rng is None else rng
    # randint's upper bound is exclusive; the +1 makes the final window
    # (the one ending on the last character) a valid choice.
    start = rng.randint(0, len(text) - seq_length + 1)
    return text[start:start + seq_length]


def generate(model, pattern, n_vocab, int_to_char, n_chars=N_CHARS):
    """Greedily generate ``n_chars`` characters, yielding them one at a time.

    Each step feeds the current window of character indices to ``model``,
    takes the argmax of its output as the next index, then slides the window
    forward by one so its length stays ``len(pattern)``.

    Args:
        model: callable taking a float32 tensor of shape
            ``(1, len(pattern), 1)``; the argmax of its output is used as the
            next character index.
        pattern: initial window of character indices (not modified).
        n_vocab: vocabulary size; indices are divided by it before the model call.
        int_to_char: maps a character index back to its character.
        n_chars: number of characters to generate.

    Yields:
        The generated characters, in order.
    """
    window = list(pattern)  # private copy so the caller's list is untouched
    for _ in range(n_chars):
        # Scale indices by vocab size and shape as (1, len(window), 1).
        x = np.reshape(window, (1, len(window), 1)) / float(n_vocab)
        x = torch.tensor(x, dtype=torch.float32)
        # no_grad wraps only the forward pass, so grad mode is not left
        # disabled in the caller's code while this generator is suspended.
        with torch.no_grad():
            prediction = model(x)
        index = int(prediction.argmax())
        yield int_to_char[index]
        window = window[1:] + [index]


def main():
    """Sample a random prompt from the corpus and print a greedy continuation."""
    model, char_to_int = load_model(CHECKPOINT_PATH)
    n_vocab = len(char_to_int)
    int_to_char = {i: c for c, i in char_to_int.items()}

    # Lowercased, presumably to match the text the vocabulary was built from.
    with open(TEXT_PATH, "r", encoding="utf-8") as f:
        raw_text = f.read().lower()

    prompt = random_prompt(raw_text, SEQ_LENGTH)
    pattern = [char_to_int[c] for c in prompt]

    print(f'Prompt:\n{prompt}')
    print("==="*15, "Сгенерированный результат", "==="*15, sep=" ")
    for char in generate(model, pattern, n_vocab, int_to_char, N_CHARS):
        print(char, end="")
    print()
    print("==="*30)
    print("Done.")


if __name__ == "__main__":
    main()
|
|
|