2024-05-25 17:38:24 +04:00
|
|
|
import pickle
|
|
|
|
import numpy as np
|
|
|
|
import tensorflow as tf
|
|
|
|
from keras.src.utils import pad_sequences
|
|
|
|
|
2024-06-02 16:38:08 +04:00
|
|
|
from enums import TypeMood, TypeModel
|
|
|
|
|
2024-05-25 17:38:24 +04:00
|
|
|
# Загрузка модели
|
2024-06-02 16:38:08 +04:00
|
|
|
model_lstm_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_lstm_negative.keras')
|
|
|
|
model_gru_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_gru_negative.keras')
|
|
|
|
model_cnn_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_cnn_negative.keras')
|
|
|
|
|
|
|
|
model_lstm_positive = tf.keras.models.load_model('.//neural_network/models/model/best_model_lstm_positive.keras')
|
|
|
|
model_gru_positive = tf.keras.models.load_model('.//neural_network/models/model/best_model_gru_positive.keras')
|
|
|
|
model_cnn_positive= tf.keras.models.load_model('.//neural_network/models/model/best_model_cnn_positive.keras')
|
|
|
|
|
2024-05-25 17:38:24 +04:00
|
|
|
|
|
|
|
# Загрузка токенизатора
|
2024-06-02 15:51:14 +04:00
|
|
|
with open('.//neural_network/tokenization/tokenizer_negative.pickle', 'rb') as handle:
|
2024-06-02 16:58:03 +04:00
|
|
|
tokenizer_negative = pickle.load(handle)
|
2024-05-25 17:38:24 +04:00
|
|
|
|
2024-05-27 23:52:33 +04:00
|
|
|
# Загрузка названий классов
|
2024-06-02 15:51:14 +04:00
|
|
|
with open('.//neural_network/classification/class_names_negative.txt', 'r', encoding='utf-8') as file:
|
2024-06-02 16:58:03 +04:00
|
|
|
class_names_negative = [line.strip() for line in file.readlines()]
|
2024-05-27 23:52:33 +04:00
|
|
|
|
2024-06-02 16:58:03 +04:00
|
|
|
# Загрузка токенизатора
|
|
|
|
with open('.//neural_network/tokenization/tokenizer_positive.pickle', 'rb') as handle:
|
|
|
|
tokenizer_positive = pickle.load(handle)
|
|
|
|
|
|
|
|
# Загрузка названий классов
|
|
|
|
with open('.//neural_network/classification/class_names_positive.txt', 'r', encoding='utf-8') as file:
|
|
|
|
class_names_positive = [line.strip() for line in file.readlines()]
|
|
|
|
|
|
|
|
def preprocess_text(text: str, type_mood: TypeMood):
|
|
|
|
if type_mood == TypeMood.NEGATIVE:
|
|
|
|
tokenizer = tokenizer_negative
|
|
|
|
elif type_mood == TypeMood.POSITIVE:
|
|
|
|
tokenizer = tokenizer_positive
|
|
|
|
else:
|
|
|
|
raise ValueError("Unsupported model type")
|
2024-05-25 17:38:24 +04:00
|
|
|
# Токенизация текста
|
|
|
|
sequences = tokenizer.texts_to_sequences([text])
|
2024-06-02 16:38:08 +04:00
|
|
|
# Преобразование последовательностей в фиксированной длине
|
2024-05-25 17:38:24 +04:00
|
|
|
padded_sequences = pad_sequences(sequences, maxlen=90) # 90 - длина последовательности, используемая при обучении
|
|
|
|
return padded_sequences
|
|
|
|
|
2024-06-02 16:38:08 +04:00
|
|
|
def predict_answer(question: str, type_mood: TypeMood, type_model: TypeModel) -> str:
|
|
|
|
if type_model == TypeModel.LSTM and type_mood == TypeMood.NEGATIVE:
|
|
|
|
model = model_lstm_negative
|
2024-06-02 16:58:03 +04:00
|
|
|
class_names = class_names_negative
|
2024-06-02 16:38:08 +04:00
|
|
|
elif type_model == TypeModel.LSTM and type_mood == TypeMood.POSITIVE:
|
|
|
|
model = model_lstm_positive
|
2024-06-02 16:58:03 +04:00
|
|
|
class_names = class_names_positive
|
2024-06-02 16:38:08 +04:00
|
|
|
elif type_model == TypeModel.GRU and type_mood == TypeMood.NEGATIVE:
|
|
|
|
model = model_gru_negative
|
2024-06-02 16:58:03 +04:00
|
|
|
class_names = class_names_negative
|
2024-06-02 16:38:08 +04:00
|
|
|
elif type_model == TypeModel.GRU and type_mood == TypeMood.POSITIVE:
|
|
|
|
model = model_gru_positive
|
2024-06-02 16:58:03 +04:00
|
|
|
class_names = class_names_positive
|
2024-06-02 16:38:08 +04:00
|
|
|
elif type_model == TypeModel.CNN and type_mood == TypeMood.NEGATIVE:
|
|
|
|
model = model_cnn_negative
|
2024-06-02 16:58:03 +04:00
|
|
|
class_names = class_names_negative
|
2024-06-02 16:38:08 +04:00
|
|
|
elif type_model == TypeModel.CNN and type_mood == TypeMood.POSITIVE:
|
|
|
|
model = model_cnn_positive
|
2024-06-02 16:58:03 +04:00
|
|
|
class_names = class_names_positive
|
2024-06-02 16:38:08 +04:00
|
|
|
else:
|
|
|
|
raise ValueError("Unsupported model type")
|
|
|
|
|
2024-05-25 17:38:24 +04:00
|
|
|
# Предобработка вопроса
|
2024-06-02 16:58:03 +04:00
|
|
|
input_data = preprocess_text(question, type_mood)
|
2024-05-25 17:38:24 +04:00
|
|
|
# Предсказание
|
2024-06-02 16:38:08 +04:00
|
|
|
prediction = model.predict(input_data)[0]
|
2024-05-27 23:52:33 +04:00
|
|
|
# Получение имени класса
|
2024-06-02 16:38:08 +04:00
|
|
|
predicted_index = np.argmax(prediction)
|
2024-05-27 23:52:33 +04:00
|
|
|
predicted_class = class_names[predicted_index]
|
2024-06-02 16:38:08 +04:00
|
|
|
return predicted_class, prediction
|