maksim
1f5bd6bdbf
Завтра сделаю фильтрованый датасет по городам, чтобы отдавать также инфу из городов
39 lines
1.9 KiB
Python
39 lines
1.9 KiB
Python
import pickle
|
||
import numpy as np
|
||
import tensorflow as tf
|
||
from keras.src.legacy.preprocessing.text import Tokenizer
|
||
from keras.src.utils import pad_sequences
|
||
|
||
# Загрузка модели
|
||
model = tf.keras.models.load_model('.//neural_network/models/model/best_model_lstm_negative.keras')
|
||
|
||
# Загрузка токенизатора
|
||
with open('neural_network/models/tokenization/tokenizer_lstm_lstm_negative.pickle', 'rb') as handle:
|
||
tokenizer = pickle.load(handle)
|
||
|
||
# Загрузка названий классов
|
||
with open('neural_network/models/classification/class_names_lstm_negative.txt', 'r', encoding='utf-8') as file:
|
||
class_names = [line.strip() for line in file.readlines()]
|
||
|
||
def preprocess_text(text: str):
|
||
# Токенизация текста
|
||
sequences = tokenizer.texts_to_sequences([text])
|
||
# Преобразование последовательностей в фиксированной длины
|
||
padded_sequences = pad_sequences(sequences, maxlen=90) # 90 - длина последовательности, используемая при обучении
|
||
return padded_sequences
|
||
|
||
def predict_answer(question: str) -> str:
|
||
# Предобработка вопроса
|
||
print("Вопрос:", question)
|
||
input_data = preprocess_text(question)
|
||
print("Предобработанные данные:", input_data)
|
||
# Предсказание
|
||
prediction = model.predict(input_data)
|
||
print("Предсказание:", prediction)
|
||
# Определение индекса класса с наибольшей вероятностью
|
||
predicted_index = np.argmax(prediction[0])
|
||
# Получение имени класса
|
||
predicted_class = class_names[predicted_index]
|
||
print("Предсказанный класс:", predicted_class)
|
||
return predicted_class # Возвращаем имя предсказанного класса
|