import pickle
import numpy as np
import tensorflow as tf
from keras.src.legacy.preprocessing.text import Tokenizer
from keras.src.utils import pad_sequences

from enums import TypeMood, TypeModel

# Загрузка модели
model_lstm_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_lstm_negative.keras')
model_gru_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_gru_negative.keras')
model_cnn_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_cnn_negative.keras')

model_lstm_positive = tf.keras.models.load_model('.//neural_network/models/model/best_model_lstm_positive.keras')
model_gru_positive = tf.keras.models.load_model('.//neural_network/models/model/best_model_gru_positive.keras')
model_cnn_positive= tf.keras.models.load_model('.//neural_network/models/model/best_model_cnn_positive.keras')


# Загрузка токенизатора
with open('.//neural_network/tokenization/tokenizer_negative.pickle', 'rb') as handle:
    tokenizer_negative = pickle.load(handle)

# Загрузка названий классов
with open('.//neural_network/classification/class_names_negative.txt', 'r', encoding='utf-8') as file:
    class_names_negative = [line.strip() for line in file.readlines()]

# Загрузка токенизатора
with open('.//neural_network/tokenization/tokenizer_positive.pickle', 'rb') as handle:
    tokenizer_positive = pickle.load(handle)

# Загрузка названий классов
with open('.//neural_network/classification/class_names_positive.txt', 'r', encoding='utf-8') as file:
    class_names_positive = [line.strip() for line in file.readlines()]

def preprocess_text(text: str, type_mood: TypeMood):
    if type_mood == TypeMood.NEGATIVE:
        tokenizer = tokenizer_negative
    elif type_mood == TypeMood.POSITIVE:
        tokenizer = tokenizer_positive
    else:
        raise ValueError("Unsupported model type")
    # Токенизация текста
    sequences = tokenizer.texts_to_sequences([text])
    # Преобразование последовательностей в фиксированной длине
    padded_sequences = pad_sequences(sequences, maxlen=90)  # 90 - длина последовательности, используемая при обучении
    return padded_sequences

def predict_answer(question: str, type_mood: TypeMood, type_model: TypeModel) -> str:
    if type_model == TypeModel.LSTM and type_mood == TypeMood.NEGATIVE:
        model = model_lstm_negative
        class_names = class_names_negative
    elif type_model == TypeModel.LSTM and type_mood == TypeMood.POSITIVE:
        model = model_lstm_positive
        class_names = class_names_positive
    elif type_model == TypeModel.GRU and type_mood == TypeMood.NEGATIVE:
        model = model_gru_negative
        class_names = class_names_negative
    elif type_model == TypeModel.GRU and type_mood == TypeMood.POSITIVE:
        model = model_gru_positive
        class_names = class_names_positive
    elif type_model == TypeModel.CNN and type_mood == TypeMood.NEGATIVE:
        model = model_cnn_negative
        class_names = class_names_negative
    elif type_model == TypeModel.CNN and type_mood == TypeMood.POSITIVE:
        model = model_cnn_positive
        class_names = class_names_positive
    else:
        raise ValueError("Unsupported model type")

    # Предобработка вопроса
    input_data = preprocess_text(question, type_mood)
    # Предсказание
    prediction = model.predict(input_data)[0]
    # Получение имени класса
    predicted_index = np.argmax(prediction)
    predicted_class = class_names[predicted_index]
    return predicted_class, prediction