From 7c3020625c27ccb16144a185835fddea10284230 Mon Sep 17 00:00:00 2001 From: maksim Date: Sat, 25 May 2024 17:38:24 +0400 Subject: [PATCH] =?UTF-8?q?=D0=A3=D1=80=D0=B0,=20=D0=BC=D0=B8=D0=BD=D0=B8?= =?UTF-8?q?=20=D0=BF=D0=BE=D0=B1=D0=B5=D0=B4=D0=B0,=20=D1=8F=20=D0=BC?= =?UTF-8?q?=D0=BE=D0=B3=D1=83=20=D1=81=D0=BE=D1=85=D1=80=D0=B0=D0=BD=D1=8F?= =?UTF-8?q?=D1=82=D1=8C=20=D0=B2=D0=B5=D1=80=D0=BE=D1=8F=D1=82=D0=BD=D0=BE?= =?UTF-8?q?=D1=82=D1=81=D0=B8.=20=D0=9F=D0=B0=D0=B1=D0=B5=D0=B4=D0=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model.py | 33 +++++++++++++++++++++++++++++++++ repository.py | 7 +++++++ requirements.txt | Bin 572 -> 610 bytes 3 files changed, 40 insertions(+) create mode 100644 model.py diff --git a/model.py b/model.py new file mode 100644 index 0000000..42de820 --- /dev/null +++ b/model.py @@ -0,0 +1,33 @@ +import pickle +import numpy as np +import tensorflow as tf +from keras.src.legacy.preprocessing.text import Tokenizer +from keras.src.utils import pad_sequences + +# Загрузка модели +model = tf.keras.models.load_model('best_model_lstm_negative.keras') + +# Загрузка токенизатора +with open('tokenizer_lstm_lstm_negative.pickle', 'rb') as handle: + tokenizer = pickle.load(handle) + +def preprocess_text(text: str): + # Токенизация текста + sequences = tokenizer.texts_to_sequences([text]) + # Преобразование последовательностей в фиксированной длины + padded_sequences = pad_sequences(sequences, maxlen=90) # 90 - длина последовательности, используемая при обучении + return padded_sequences + +def predict_answer(question: str) -> str: + # Предобработка вопроса + print("Вопрос:", question) + input_data = preprocess_text(question) + print("Предобработанные данные:", input_data) + # Предсказание + prediction = model.predict(input_data) + print("Предсказание:", prediction) + # Преобразование предсказания в строку для сохранения + prediction_str = np.array2string(prediction[0], separator=',') + print("Строковое представление предсказания:", prediction_str) + return prediction_str # Возвращаем строковое представление предсказания + diff --git a/repository.py b/repository.py index 447e4b2..dd8281b 100644 --- a/repository.py +++ b/repository.py @@ -2,12 +2,19 @@ from sqlalchemy import select from database import new_session, QuestionOrm from schemas import SQuestionAdd, SQuestion +from model import predict_answer + class QuestionRepository: @classmethod async def add_one(cls, data: SQuestionAdd) -> int: async with new_session() as session: question_dict = data.model_dump() + + # Предсказание ответа с помощью модели + answer = predict_answer(question_dict["question"]) + + question_dict["answer"] = answer question = QuestionOrm(**question_dict) session.add(question) await session.flush() diff --git a/requirements.txt b/requirements.txt index 712cf652d3530a59c89d60688e1f3af52aa7ccc2..3b537bb31a7cb9d10c79fac4ef9e2cfa70ad9971 100644 GIT binary patch delta 46 vcmdnP@`z=F4U<|4Ln=caLoq`>LlHw7Lk^HGXW(VvV#o%H7cnFPWk5Ut3Y`g3 delta 7 OcmaFFvWI1Z4HEzilLD0h