25 lines
669 B
Python
25 lines
669 B
Python
from config import PREDICTION_LEN_START
|
|
from data import text_to_seq
|
|
from generation import evaluate
|
|
from train import create_parameters, training
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the data, train, then print one generated sample.

    # text_to_seq() supplies the sequence and the two lookup tables unpacked here.
    sequence, char_to_idx, idx_to_char = text_to_seq()

    # create_parameters() returns a 7-tuple; the unpacking order below must
    # match its return order exactly.
    (
        criterion,
        scheduler,
        n_epochs,
        loss_avg,
        device,
        model,
        optimizer,
    ) = create_parameters(idx_to_char)

    # Arguments are passed positionally, so their order mirrors the
    # signature of training() in the train module.
    training(
        n_epochs,
        model,
        sequence,
        device,
        criterion,
        optimizer,
        loss_avg,
        scheduler,
        char_to_idx,
        idx_to_char,
    )

    # NOTE(review): presumably puts the model into inference mode
    # (torch-style Module.eval()) before sampling -- confirm against the
    # model class defined in the train module.
    model.eval()

    generated = evaluate(
        model,
        char_to_idx,
        idx_to_char,
        device=device,
        temp=0.3,
        prediction_len=PREDICTION_LEN_START,
        start_text='. ',
    )
    print(generated)
|
|
|