"""Generate text one character at a time from a trained ``CharModel``.

Loads a checkpoint holding the model weights and the character-to-index
vocabulary, picks a random prompt from the corpus file, and then greedily
extends it by 1000 characters, printing each character as it is produced.
"""
import torch
from model import CharModel
import numpy as np

if __name__ == "__main__":
    # The checkpoint stores a pair: (model state dict, char -> int mapping).
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    best_model, char_to_int = torch.load("single-char.pth")
    n_vocab = len(char_to_int)
    # Inverse mapping, used to turn predicted indices back into characters.
    int_to_char = {i: c for c, i in char_to_int.items()}
    model = CharModel()
    model.load_state_dict(best_model)

    # randomly generate a prompt
    filename = "wonderland.txt"
    seq_length = 100
    # Use a context manager so the file handle is closed deterministically.
    with open(filename, 'r', encoding='utf-8') as corpus:
        raw_text = corpus.read()
    raw_text = raw_text.lower()
    # Pick a random window of seq_length characters from the corpus.
    start = np.random.randint(0, len(raw_text) - seq_length)
    prompt = raw_text[start:start + seq_length]
    # Raises KeyError if the prompt contains a character missing from the
    # checkpoint's vocabulary.
    pattern = [char_to_int[c] for c in prompt]

    model.eval()
    print(f'Prompt:\n{prompt}')
    print("==="*15, "Сгенерированный результ", "==="*15, sep=" ")

    # Indices are scaled by the vocabulary size before being fed to the model.
    scale = float(n_vocab)
    with torch.no_grad():
        for _ in range(1000):
            # format input array of int into PyTorch tensor of shape
            # (1, len(pattern), 1)
            x = np.reshape(pattern, (1, len(pattern), 1)) / scale
            x = torch.tensor(x, dtype=torch.float32)
            # generate logits as output from the model
            prediction = model(x)
            # convert logits into one character (greedy: highest-scoring index)
            index = int(prediction.argmax())
            result = int_to_char[index]
            print(result, end="")
            # append the new character into the prompt for the next iteration
            # and drop the oldest one so the window length stays constant
            pattern.append(index)
            pattern = pattern[1:]

    print()
    print("==="*30)
    print("Done.")