Добавил гигачада в метод. Очень кринж, завтра надо фулл рефакторинг сделать, перед запиский
This commit is contained in:
parent
27148b8638
commit
0cfc3d36aa
2
enums.py
2
enums.py
@ -3,8 +3,10 @@ from enum import Enum
|
||||
class TypeMood(str, Enum):
|
||||
POSITIVE = "Positive"
|
||||
NEGATIVE = "Negative"
|
||||
NEUTRAL = "Neutral"
|
||||
|
||||
class TypeModel(str, Enum):
|
||||
LSTM = "LSTM"
|
||||
GRU = "GRU"
|
||||
CNN = "CNN"
|
||||
GIGACHAD = "GigaChad"
|
||||
|
@ -5,6 +5,7 @@ import urllib.parse
|
||||
|
||||
from database import new_session, QuestionOrm
|
||||
from enums import TypeMood, TypeModel
|
||||
from gigachat import giga_token, get_chat_completion
|
||||
from schemas import SQuestionAdd, SQuestion
|
||||
from model import predict_answer
|
||||
|
||||
@ -19,14 +20,24 @@ class QuestionRepository:
|
||||
decoded_mood = urllib.parse.unquote(type_mood.value)
|
||||
decoded_model = urllib.parse.unquote(type_model.value)
|
||||
|
||||
# Предсказание ответа с помощью модели
|
||||
predicted_class, prediction = predict_answer(decoded_question, decoded_mood, decoded_model)
|
||||
|
||||
# Проверка вероятностей классов
|
||||
if max(prediction) < 0.2:
|
||||
answer = "Not Found"
|
||||
if decoded_model == "GigaChad":
|
||||
# Предсказание ответа с помощью модели
|
||||
predicted_class = get_chat_completion(giga_token, decoded_question)
|
||||
prediction = 100.00
|
||||
else:
|
||||
answer = predicted_class
|
||||
# Предсказание ответа с помощью модели
|
||||
predicted_class, prediction = predict_answer(decoded_question, decoded_mood, decoded_model)
|
||||
|
||||
if isinstance(prediction, float):
|
||||
if prediction < 0.2:
|
||||
answer = "Not Found"
|
||||
else:
|
||||
answer = predicted_class
|
||||
else:
|
||||
if max(prediction) < 0.2:
|
||||
answer = "Not Found"
|
||||
else:
|
||||
answer = predicted_class
|
||||
|
||||
# Обновление декодированных значений в словаре
|
||||
question_dict["question"] = decoded_question
|
||||
|
@ -31,11 +31,3 @@ async def get_questions() -> list[SQuestion]:
|
||||
async def get_questions_by_email(email_user: str) -> list[SQuestion]:
|
||||
questions = await QuestionRepository.find_by_email(email_user)
|
||||
return questions
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def add_question(question: str):
|
||||
# Получение ответа от нейронной сети
|
||||
neural_response = get_chat_completion(giga_token, question)
|
||||
|
||||
return {"answer": neural_response}
|
||||
|
Loading…
Reference in New Issue
Block a user