Добавил гигачада в метод. Очень кринж, завтра надо фулл рефакторинг сделать, перед запиский

This commit is contained in:
maksim 2024-06-17 00:19:54 +04:00
parent 27148b8638
commit 0cfc3d36aa
3 changed files with 20 additions and 15 deletions

View File

@ -3,8 +3,10 @@ from enum import Enum
class TypeMood(str, Enum):
POSITIVE = "Positive"
NEGATIVE = "Negative"
NEUTRAL = "Neutral"
class TypeModel(str, Enum):
LSTM = "LSTM"
GRU = "GRU"
CNN = "CNN"
GIGACHAD = "GigaChad"

View File

@ -5,6 +5,7 @@ import urllib.parse
from database import new_session, QuestionOrm
from enums import TypeMood, TypeModel
from gigachat import giga_token, get_chat_completion
from schemas import SQuestionAdd, SQuestion
from model import predict_answer
@ -19,14 +20,24 @@ class QuestionRepository:
decoded_mood = urllib.parse.unquote(type_mood.value)
decoded_model = urllib.parse.unquote(type_model.value)
# Предсказание ответа с помощью модели
predicted_class, prediction = predict_answer(decoded_question, decoded_mood, decoded_model)
# Проверка вероятностей классов
if max(prediction) < 0.2:
answer = "Not Found"
if decoded_model == "GigaChad":
# Предсказание ответа с помощью модели
predicted_class = get_chat_completion(giga_token, decoded_question)
prediction = 100.00
else:
answer = predicted_class
# Предсказание ответа с помощью модели
predicted_class, prediction = predict_answer(decoded_question, decoded_mood, decoded_model)
if isinstance(prediction, float):
if prediction < 0.2:
answer = "Not Found"
else:
answer = predicted_class
else:
if max(prediction) < 0.2:
answer = "Not Found"
else:
answer = predicted_class
# Обновление декодированных значений в словаре
question_dict["question"] = decoded_question

View File

@ -31,11 +31,3 @@ async def get_questions() -> list[SQuestion]:
async def get_questions_by_email(email_user: str) -> list[SQuestion]:
questions = await QuestionRepository.find_by_email(email_user)
return questions
@router.post("/chat")
async def add_question(question: str):
# Получение ответа от нейронной сети
neural_response = get_chat_completion(giga_token, question)
return {"answer": neural_response}