IIS_2023_1/malkova_anastasia_lab_7/main.py

25 lines
669 B
Python
Raw Normal View History

2023-11-17 01:50:49 +04:00
from config import PREDICTION_LEN_START
from data import text_to_seq
from generation import evaluate
from train import create_parameters, training
def main():
    """Train the character-level model, then print one generated text sample.

    Steps:
    1. Encode the corpus and build the char <-> index lookup tables.
    2. Build the training objects.
    3. Run training.
    4. Sample text from the trained model, seeded with '. '.
    """
    # text_to_seq() yields the encoded corpus plus both lookup directions.
    sequence, char_to_idx, idx_to_char = text_to_seq()

    # The unpack order below must match what create_parameters returns.
    # It is given idx_to_char, presumably to size the vocabulary (verify in train.py).
    (criterion, scheduler, n_epochs, loss_avg,
     device, model, optimizer) = create_parameters(idx_to_char)

    # training() takes everything positionally, so argument order matters here.
    training(
        n_epochs, model, sequence, device, criterion,
        optimizer, loss_avg, scheduler, char_to_idx, idx_to_char,
    )

    # Put the model into inference mode before sampling.
    # NOTE(review): assumes model is a torch.nn.Module, where eval() changes
    # dropout/batch-norm behaviour. Confirm against create_parameters.
    model.eval()

    # temp=0.3 is a fixed sampling temperature.
    # The output length comes from the project config.
    generated_text = evaluate(
        model,
        char_to_idx,
        idx_to_char,
        device=device,
        temp=0.3,
        prediction_len=PREDICTION_LEN_START,
        start_text='. ',
    )
    print(generated_text)


if __name__ == '__main__':
    main()