IIS_2023_1/romanova_adelina_lab_7/train.py

86 lines
2.7 KiB
Python
Raw Normal View History

2023-12-25 01:19:51 +04:00
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

from model import CharModel
def get_data(filename="wonderland.txt", seq_length=100):
    """Load a text corpus and build a next-character prediction dataset.

    The text is lower-cased and every distinct character is mapped to an
    integer index (indices follow sorted character order). Each sample is a
    window of ``seq_length`` consecutive characters; its target is the
    character immediately following that window.

    Args:
        filename: Path to a UTF-8 text file. Defaults to ``"wonderland.txt"``.
        seq_length: Number of characters per input window (default 100).

    Returns:
        A tuple ``(X, y, char_to_int)``:
        ``X`` is a float32 tensor of shape ``(n_patterns, seq_length, 1)``
        holding character indices divided by the vocabulary size;
        ``y`` is an int64 tensor of shape ``(n_patterns,)`` with the target
        character indices; ``char_to_int`` maps each character to its index.

    Raises:
        ValueError: If ``seq_length`` is not positive.
        OSError: If the file cannot be opened.
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be positive, got {seq_length}")
    # Load the dataset and lower-case it; the context manager closes the file.
    with open(filename, 'r', encoding='utf-8') as f:
        raw_text = f.read().lower()
    # Map every distinct character to an integer index.
    chars = sorted(set(raw_text))
    char_to_int = {c: i for i, c in enumerate(chars)}
    # Training-data statistics.
    n_chars = len(raw_text)
    n_vocab = len(chars)
    print("Total Characters: ", n_chars)
    print("Total Vocab: ", n_vocab)
    # Build the dataset: encode the text once, then slice overlapping windows
    # from the encoded list instead of re-looking up each character per window.
    encoded = [char_to_int[c] for c in raw_text]
    n_patterns = max(n_chars - seq_length, 0)
    dataX = [encoded[i:i + seq_length] for i in range(n_patterns)]
    dataY = encoded[seq_length:]
    print("Total Patterns: ", n_patterns)
    # Convert to tensors for PyTorch: (samples, time steps, features), with the
    # indices scaled by the vocabulary size.
    X = torch.tensor(dataX, dtype=torch.float32).reshape(n_patterns, seq_length, 1)
    X = X / float(n_vocab)
    # Explicit dtype keeps targets int64 even when there are no patterns.
    y = torch.tensor(dataY, dtype=torch.long)
    print(X.shape, y.shape)
    return X, y, char_to_int
def main():
    """Train a CharModel on the character dataset and save the best weights.

    Trains for a fixed number of epochs with Adam and summed cross-entropy.
    After each epoch it recomputes the loss over the whole dataset and keeps a
    snapshot of the weights with the lowest loss seen so far. The snapshot and
    the character-to-index mapping are saved together as a list to
    ``single-char.pth``.
    """
    X, y, char_to_int = get_data()
    n_epochs = 40
    batch_size = 128
    model = CharModel()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")
    model.to(device)
    optimizer = optim.Adam(model.parameters())
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    loader = data.DataLoader(data.TensorDataset(X, y), shuffle=True, batch_size=batch_size)
    best_model = None
    best_loss = np.inf
    for epoch in range(n_epochs):
        model.train()
        for X_batch, y_batch in loader:
            y_pred = model(X_batch.to(device))
            loss = loss_fn(y_pred, y_batch.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Evaluation pass. NOTE: this reuses the training loader, so the value
        # is the full training-set loss, not a held-out validation loss.
        model.eval()
        loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in loader:
                y_pred = model(X_batch.to(device))
                # .item() keeps the running total a plain Python float.
                loss += loss_fn(y_pred, y_batch.to(device)).item()
        if loss < best_loss:
            best_loss = loss
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps update in place; deep-copy them so
            # the snapshot really holds this epoch's weights.
            best_model = copy.deepcopy(model.state_dict())
        print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))
    torch.save([best_model, char_to_int], "single-char.pth")