import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

from model import CharModel


def get_data(filename: str = "wonderland.txt"):
    """Build next-character training tensors from a plain-text corpus.

    The text is lower-cased, every distinct character is mapped to an
    integer id (ids follow the sorted order of the characters), and the
    text is cut into sliding windows of 100 characters, each paired with
    the character that immediately follows it.

    Args:
        filename: Path to a UTF-8 text file to read.

    Returns:
        A tuple ``(X, y, char_to_int)`` where ``X`` is a float32 tensor of
        shape ``(n_patterns, 100, 1)`` holding character ids divided by the
        vocabulary size, ``y`` is a tensor with the id of the character
        following each window, and ``char_to_int`` is the character-to-id
        mapping.

    Raises:
        OSError: If ``filename`` cannot be opened.
    """
    # Load the dataset and lower-case it; the context manager guarantees the
    # file handle is closed even if reading fails.
    with open(filename, "r", encoding="utf-8") as corpus:
        raw_text = corpus.read().lower()

    # Map every distinct character to an integer id.
    chars = sorted(set(raw_text))
    char_to_int = {c: i for i, c in enumerate(chars)}

    # Statistics of the training data.
    n_chars = len(raw_text)
    n_vocab = len(chars)
    print("Total Characters: ", n_chars)
    print("Total Vocab: ", n_vocab)

    # Prepare the dataset: encode the text once, then slice windows out of
    # the encoded list instead of re-encoding every character per window.
    seq_length = 100
    encoded = [char_to_int[c] for c in raw_text]
    dataX = [encoded[i:i + seq_length] for i in range(n_chars - seq_length)]
    dataY = encoded[seq_length:]
    n_patterns = len(dataX)
    print("Total Patterns: ", n_patterns)

    # Convert to tensors so the data can be used inside PyTorch. Inputs are
    # scaled by the vocabulary size; targets stay as raw ids.
    X = torch.tensor(dataX, dtype=torch.float32).reshape(n_patterns, seq_length, 1)
    X = X / float(n_vocab)
    y = torch.tensor(dataY)
    print(X.shape, y.shape)
    return X, y, char_to_int


def main() -> None:
    """Train ``CharModel`` on the default corpus and save the best weights.

    Runs 40 epochs of Adam with a summed cross-entropy loss and batch size
    128. After every epoch the loss is re-measured in eval mode; the weights
    with the lowest such loss are written, together with the character
    mapping, to ``single-char.pth``.
    """
    X, y, char_to_int = get_data()

    n_epochs = 40
    batch_size = 128
    # NOTE(review): CharModel takes no arguments here, so presumably its
    # output size is fixed inside model.py - confirm it matches the
    # vocabulary size reported by get_data().
    model = CharModel()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")
    model.to(device)

    optimizer = optim.Adam(model.parameters())
    # reduction="sum" makes the per-epoch total below a sum over all samples.
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    loader = data.DataLoader(data.TensorDataset(X, y), shuffle=True, batch_size=batch_size)

    best_model = None
    best_loss = np.inf
    for epoch in range(n_epochs):
        model.train()
        for X_batch, y_batch in loader:
            y_pred = model(X_batch.to(device))
            loss = loss_fn(y_pred, y_batch.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluation pass. It iterates the same loader as training, so this
        # is the training-set loss in eval mode, not a held-out validation
        # loss.
        model.eval()
        loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in loader:
                y_pred = model(X_batch.to(device))
                # .item() keeps the running total a plain Python float.
                loss += loss_fn(y_pred, y_batch.to(device)).item()
        if loss < best_loss:
            best_loss = loss
            # state_dict() returns references to the live parameters, which
            # later optimizer steps overwrite; deep-copy to keep a snapshot.
            best_model = copy.deepcopy(model.state_dict())
        print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))

    torch.save([best_model, char_to_int], "single-char.pth")