diff --git a/model.py b/model.py index 5bfca2f..714f4eb 100644 --- a/model.py +++ b/model.py @@ -18,13 +18,27 @@ model_cnn_positive= tf.keras.models.load_model('.//neural_network/models/model/b # Загрузка токенизатора with open('.//neural_network/tokenization/tokenizer_negative.pickle', 'rb') as handle: - tokenizer = pickle.load(handle) + tokenizer_negative = pickle.load(handle) # Загрузка названий классов with open('.//neural_network/classification/class_names_negative.txt', 'r', encoding='utf-8') as file: - class_names = [line.strip() for line in file.readlines()] + class_names_negative = [line.strip() for line in file.readlines()] -def preprocess_text(text: str): +# Загрузка токенизатора +with open('.//neural_network/tokenization/tokenizer_positive.pickle', 'rb') as handle: + tokenizer_positive = pickle.load(handle) + +# Загрузка названий классов +with open('.//neural_network/classification/class_names_positive.txt', 'r', encoding='utf-8') as file: + class_names_positive = [line.strip() for line in file.readlines()] + +def preprocess_text(text: str, type_mood: TypeMood): + if type_mood == TypeMood.NEGATIVE: + tokenizer = tokenizer_negative + elif type_mood == TypeMood.POSITIVE: + tokenizer = tokenizer_positive + else: + raise ValueError("Unsupported model type") # Токенизация текста sequences = tokenizer.texts_to_sequences([text]) # Преобразование последовательностей в фиксированной длине @@ -34,21 +48,27 @@ def preprocess_text(text: str): def predict_answer(question: str, type_mood: TypeMood, type_model: TypeModel) -> str: if type_model == TypeModel.LSTM and type_mood == TypeMood.NEGATIVE: model = model_lstm_negative + class_names = class_names_negative elif type_model == TypeModel.LSTM and type_mood == TypeMood.POSITIVE: model = model_lstm_positive + class_names = class_names_positive elif type_model == TypeModel.GRU and type_mood == TypeMood.NEGATIVE: model = model_gru_negative + class_names = class_names_negative elif type_model == TypeModel.GRU and type_mood == TypeMood.POSITIVE: model = model_gru_positive + class_names = class_names_positive elif type_model == TypeModel.CNN and type_mood == TypeMood.NEGATIVE: model = model_cnn_negative + class_names = class_names_negative elif type_model == TypeModel.CNN and type_mood == TypeMood.POSITIVE: model = model_cnn_positive + class_names = class_names_positive else: raise ValueError("Unsupported model type") # Предобработка вопроса - input_data = preprocess_text(question) + input_data = preprocess_text(question, type_mood) # Предсказание prediction = model.predict(input_data)[0] # Получение имени класса