import numpy as np
from keras_preprocessing.sequence import pad_sequences
from keras_preprocessing.text import Tokenizer
from keras.models import Sequential
from keras.layers import Dense, LSTM, Embedding, Dropout
from keras.callbacks import ModelCheckpoint


def recreate_model(predictors, labels, model, filepath, epoch_num):
    """Compile *model* from scratch and train it for *epoch_num* epochs.

    Args:
        predictors: Padded input token sequences passed to ``model.fit``.
        labels: Integer target token ids (used with sparse categorical
            cross-entropy, so they are class indices, not one-hot vectors).
        model: An uncompiled/compiled Keras model; it is (re)compiled here.
        filepath: Where ``append_epochs`` checkpoints the model.
        epoch_num: Number of training epochs.
    """
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Bug fix: ``filepath`` was previously omitted, which made this call raise
    # TypeError because append_epochs requires all five arguments.
    append_epochs(predictors, labels, model, filepath, epoch_num)


def append_epochs(predictors, labels, model, filepath, epoch_num):
    """Continue training *model* for *epoch_num* more epochs.

    A ``ModelCheckpoint`` callback saves the model to *filepath* each time the
    training loss (``monitor='loss'``, ``mode='min'``) improves on the lowest
    value seen so far in this call. The "best so far" tracking starts fresh on
    every call, so the first epoch of a new call always writes the file.

    Args:
        predictors: Model inputs passed to ``model.fit``.
        labels: Training targets passed to ``model.fit``.
        model: A compiled Keras model.
        filepath: Checkpoint destination.
        epoch_num: Number of epochs to train.
    """
    best_loss_saver = ModelCheckpoint(
        filepath,
        monitor='loss',
        verbose=1,
        save_best_only=True,
        mode='min',
    )
    model.fit(
        predictors,
        labels,
        epochs=epoch_num,
        verbose=1,
        callbacks=[best_loss_saver],
    )


def generate_text(tokenizer, seed_text, next_words, model, max_seq_length):
    """Extend *seed_text* with *next_words* greedily predicted words.

    Each step tokenizes the current text, left-pads it to
    ``max_seq_length - 1`` tokens (the model's input length), takes the argmax
    of the model's output, and appends the corresponding word.

    Args:
        tokenizer: A fitted Keras ``Tokenizer``.
        seed_text: Starting text.
        next_words: How many words to append.
        model: Trained next-word model.
        max_seq_length: Length of the padded training sequences (input + label).

    Returns:
        The seed text followed by the generated words, space-separated.
    """
    # Build the reverse lookup once instead of scanning word_index for every
    # generated word (previously O(vocabulary) per word).
    index_to_word = {index: word for word, index in tokenizer.word_index.items()}
    for _ in range(next_words):
        token_list = tokenizer.texts_to_sequences([seed_text])[0]
        token_list = pad_sequences([token_list], maxlen=max_seq_length - 1, padding='pre')
        # verbose=0: avoid printing a progress bar for every generated word.
        probabilities = model.predict(token_list, verbose=0)
        # argmax over a batch of one -> take the single element as a plain int.
        predicted = int(np.argmax(probabilities, axis=-1)[0])
        # Index 0 is reserved for padding and maps to no word; as before, an
        # unmapped index contributes an empty word.
        seed_text += " " + index_to_word.get(predicted, "")
    return seed_text


# Menu choice -> (training text file, model checkpoint file).
_CORPORA = {
    1: ("data.txt", "model_eng.hdf5"),
    2: ("rus_data.txt", "model_rus.hdf5"),
}


def _ask_int(prompt):
    """Prompt until the user enters a valid integer, then return it."""
    while True:
        answer = input(prompt)
        try:
            return int(answer)
        except ValueError:
            print("Please enter a whole number.")


def _confirmed(question):
    """Return True only if the user answers 'yes' to *question* and to a confirmation."""
    if input(question) != 'yes':
        return False
    return input("Are you sure? (print yes): ") == 'yes'


def _ask_epochs():
    """Ask for an epoch count; return it if it is in 1..99, otherwise None."""
    num = _ask_int("Select number of epoch: ")
    if 0 < num < 100:
        return num
    print("Number of epochs must be between 1 and 99; skipping.")
    return None


def _build_training_data(text, tokenizer):
    """Turn *text* into (predictors, labels, max_seq_length) n-gram training data.

    For every line, each prefix of length >= 2 of its token list becomes one
    sample. Samples are left-padded to the longest prefix; the last token is
    the label and the rest are the predictors.
    """
    input_sequences = []
    for line in text.split('\n'):
        token_list = tokenizer.texts_to_sequences([line])[0]
        input_sequences.extend(token_list[:i + 1] for i in range(1, len(token_list)))

    if not input_sequences:
        raise ValueError("Training text contains no line with at least two words.")

    max_seq_length = max(len(seq) for seq in input_sequences)
    padded = pad_sequences(input_sequences, maxlen=max_seq_length, padding='pre')
    return padded[:, :-1], padded[:, -1], max_seq_length


def _build_model(words_count, max_seq_length):
    """Create the (uncompiled) Embedding -> LSTM -> Dropout -> softmax model."""
    model = Sequential()
    model.add(Embedding(words_count, 100, input_length=max_seq_length - 1))
    model.add(LSTM(150))
    model.add(Dropout(0.15))
    model.add(Dense(words_count, activation='softmax'))
    return model


def start():
    """Run the interactive train-and-generate session.

    Asks which corpus to use, builds n-gram training data from it, and
    optionally recreates or continues training the model. It then loads the
    checkpointed weights and generates 25-word continuations of user-supplied
    seeds until the user stops.
    """
    import os

    flag = -1
    while flag not in _CORPORA:
        flag = _ask_int("Select model and text (1 - eng, 2 - ru): ")
    data_path, filepath = _CORPORA[flag]

    # NOTE(review): no encoding is given, so the platform default is used;
    # rus_data.txt may need an explicit encoding (e.g. "utf-8") — confirm.
    with open(data_path) as data_file:
        text = data_file.read()

    tokenizer = Tokenizer()
    tokenizer.fit_on_texts([text])
    words_count = len(tokenizer.word_index) + 1  # +1: index 0 is padding

    predictors, labels, max_seq_length = _build_training_data(text, tokenizer)
    model = _build_model(words_count, max_seq_length)

    if _confirmed("Do you want to recreate the model ? (print yes): "):
        num = _ask_epochs()
        if num is not None:
            recreate_model(predictors, labels, model, filepath, num)

    if not os.path.isfile(filepath):
        raise SystemExit(f"No saved model at {filepath!r}; recreate the model first.")
    model.load_weights(filepath)
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    if _confirmed("Do you want to train the model ? (print yes): "):
        num = _ask_epochs()
        if num is not None:
            append_epochs(predictors, labels, model, filepath, num)

    flag = 'y'
    while flag == 'y':
        seed = input("Enter seed: ")
        print(generate_text(tokenizer, seed, 25, model, max_seq_length))
        flag = input("Continue? (print \'y\'): ")


# Only launch the interactive session when run as a script, not on import.
if __name__ == "__main__":
    start()