PIbd-32_Kashin_M.I_API_Cour.../model.py

76 lines
3.7 KiB
Python
Raw Normal View History

import pickle
import numpy as np
import tensorflow as tf
from keras.src.utils import pad_sequences
from enums import TypeMood, TypeModel
# Загрузка модели
model_lstm_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_lstm_negative.keras')
model_gru_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_gru_negative.keras')
model_cnn_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_cnn_negative.keras')
model_lstm_positive = tf.keras.models.load_model('.//neural_network/models/model/best_model_lstm_positive.keras')
model_gru_positive = tf.keras.models.load_model('.//neural_network/models/model/best_model_gru_positive.keras')
model_cnn_positive= tf.keras.models.load_model('.//neural_network/models/model/best_model_cnn_positive.keras')
# Загрузка токенизатора
with open('.//neural_network/tokenization/tokenizer_negative.pickle', 'rb') as handle:
tokenizer_negative = pickle.load(handle)
2024-05-27 23:52:33 +04:00
# Загрузка названий классов
with open('.//neural_network/classification/class_names_negative.txt', 'r', encoding='utf-8') as file:
class_names_negative = [line.strip() for line in file.readlines()]
2024-05-27 23:52:33 +04:00
# Загрузка токенизатора
with open('.//neural_network/tokenization/tokenizer_positive.pickle', 'rb') as handle:
tokenizer_positive = pickle.load(handle)
# Загрузка названий классов
with open('.//neural_network/classification/class_names_positive.txt', 'r', encoding='utf-8') as file:
class_names_positive = [line.strip() for line in file.readlines()]
def preprocess_text(text: str, type_mood: TypeMood):
if type_mood == TypeMood.NEGATIVE:
tokenizer = tokenizer_negative
elif type_mood == TypeMood.POSITIVE:
tokenizer = tokenizer_positive
else:
raise ValueError("Unsupported model type")
# Токенизация текста
sequences = tokenizer.texts_to_sequences([text])
# Преобразование последовательностей в фиксированной длине
padded_sequences = pad_sequences(sequences, maxlen=90) # 90 - длина последовательности, используемая при обучении
return padded_sequences
def predict_answer(question: str, type_mood: TypeMood, type_model: TypeModel) -> str:
if type_model == TypeModel.LSTM and type_mood == TypeMood.NEGATIVE:
model = model_lstm_negative
class_names = class_names_negative
elif type_model == TypeModel.LSTM and type_mood == TypeMood.POSITIVE:
model = model_lstm_positive
class_names = class_names_positive
elif type_model == TypeModel.GRU and type_mood == TypeMood.NEGATIVE:
model = model_gru_negative
class_names = class_names_negative
elif type_model == TypeModel.GRU and type_mood == TypeMood.POSITIVE:
model = model_gru_positive
class_names = class_names_positive
elif type_model == TypeModel.CNN and type_mood == TypeMood.NEGATIVE:
model = model_cnn_negative
class_names = class_names_negative
elif type_model == TypeModel.CNN and type_mood == TypeMood.POSITIVE:
model = model_cnn_positive
class_names = class_names_positive
else:
raise ValueError("Unsupported model type")
# Предобработка вопроса
input_data = preprocess_text(question, type_mood)
# Предсказание
prediction = model.predict(input_data)[0]
2024-05-27 23:52:33 +04:00
# Получение имени класса
predicted_index = np.argmax(prediction)
2024-05-27 23:52:33 +04:00
predicted_class = class_names[predicted_index]
return predicted_class, prediction