PIbd-32_Kashin_M.I_API_Cour.../model.py

39 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pickle
import numpy as np
import tensorflow as tf
from keras.src.legacy.preprocessing.text import Tokenizer
from keras.src.utils import pad_sequences
# Загрузка модели
model = tf.keras.models.load_model('.//neural_network/create_lstm/model/best_model_lstm_negative.keras')
# Загрузка токенизатора
with open('.//neural_network/create_lstm/tokenization/tokenizer_lstm_lstm_negative.pickle', 'rb') as handle:
tokenizer = pickle.load(handle)
# Загрузка названий классов
with open('.//neural_network/create_lstm/class/class_names_lstm_negative.txt', 'r', encoding='utf-8') as file:
class_names = [line.strip() for line in file.readlines()]
def preprocess_text(text: str):
# Токенизация текста
sequences = tokenizer.texts_to_sequences([text])
# Преобразование последовательностей в фиксированной длины
padded_sequences = pad_sequences(sequences, maxlen=90) # 90 - длина последовательности, используемая при обучении
return padded_sequences
def predict_answer(question: str) -> str:
# Предобработка вопроса
print("Вопрос:", question)
input_data = preprocess_text(question)
print("Предобработанные данные:", input_data)
# Предсказание
prediction = model.predict(input_data)
print("Предсказание:", prediction)
# Определение индекса класса с наибольшей вероятностью
predicted_index = np.argmax(prediction[0])
# Получение имени класса
predicted_class = class_names[predicted_index]
print("Предсказанный класс:", predicted_class)
return predicted_class # Возвращаем имя предсказанного класса