PIbd-32_Kashin_M.I_API_Cour.../repository.py

100 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sqlalchemy import select, delete, func
from sqlalchemy.orm import joinedload
from datetime import datetime
import urllib.parse
from database import new_session, QuestionOrm
from enums import TypeMood, TypeModel
from gigachat import giga_token, get_chat_completion
from schemas import SQuestionAdd, SQuestion
from model import predict_answer
class QuestionRepository:
    """Data-access layer for ``QuestionOrm`` rows.

    Answers and stores user questions, lists stored questions, and aggregates
    answer counts per mood.
    """

    # Answers whose confidence is below this value are stored as "Not Found".
    CONFIDENCE_THRESHOLD = 0.2
    # Maximum number of stored questions kept per ``email_user``; the oldest
    # rows beyond this limit are deleted when a new question is added.
    MAX_QUESTIONS_PER_USER = 100

    @staticmethod
    def _confidence(prediction) -> float:
        """Reduce a model prediction to a single confidence score.

        ``prediction`` is either a scalar score or an iterable of scores (for
        an iterable, its maximum is used). ``numbers.Real`` is checked instead
        of ``float`` so NumPy scalars such as ``float32`` (which do not subclass
        ``float``) and plain ints are treated as scalars rather than passed to
        ``max()``. An empty iterable yields 0.0.
        """
        import numbers  # local import: only this helper needs it

        if isinstance(prediction, numbers.Real):
            return float(prediction)
        return float(max(prediction, default=0.0))

    @classmethod
    def _answer_for(cls, question: str, mood: str, model: str):
        """Return the predicted answer for ``question``.

        Returns "Not Found" when the prediction's confidence is below
        ``CONFIDENCE_THRESHOLD``.
        """
        # NOTE(review): both predictors are called synchronously and will block
        # the event loop while they run; consider offloading if they are slow.
        if model == "GigaChad":
            predicted_class = get_chat_completion(giga_token, question)
            # No confidence score is available for this model; a fixed 100.00
            # always passes the threshold.
            prediction = 100.00
        else:
            predicted_class, prediction = predict_answer(question, mood, model)
        if cls._confidence(prediction) < cls.CONFIDENCE_THRESHOLD:
            return "Not Found"
        return predicted_class

    @classmethod
    async def _trim_user_history(cls, session, email_user) -> None:
        """Delete the oldest questions of ``email_user`` beyond the per-user limit.

        Rows are counted in the database and only the excess rows, ordered by
        ``question_time`` then ``id``, are loaded and deleted. The caller is
        responsible for committing.
        """
        count_query = (
            select(func.count())
            .select_from(QuestionOrm)
            .where(QuestionOrm.email_user == email_user)
        )
        total = (await session.execute(count_query)).scalar_one()
        excess = total - cls.MAX_QUESTIONS_PER_USER
        if excess <= 0:
            return
        oldest_query = (
            select(QuestionOrm)
            .where(QuestionOrm.email_user == email_user)
            .order_by(QuestionOrm.question_time, QuestionOrm.id)
            .limit(excess)
        )
        for stale in (await session.execute(oldest_query)).scalars().all():
            await session.delete(stale)

    @classmethod
    async def add_one(
        cls, data: SQuestionAdd, type_mood: TypeMood, type_model: TypeModel
    ) -> tuple[int, str]:
        """Answer a question, store it, and prune the user's oldest questions.

        Args:
            data: the incoming question payload (must provide ``question`` and
                ``email_user``).
            type_mood: mood selector; its ``value`` is URL-decoded and stored.
            type_model: model selector; its ``value`` is URL-decoded, stored,
                and used to choose the predictor.

        Returns:
            ``(id, answer)`` of the newly stored question.
        """
        question_dict = data.model_dump()
        # Decode URL-encoded parameters.
        # NOTE(review): if the web framework already decodes query parameters,
        # this extra unquote alters questions containing a literal '%'; confirm
        # that callers send double-encoded text.
        decoded_question = urllib.parse.unquote(question_dict["question"])
        decoded_mood = urllib.parse.unquote(type_mood.value)
        decoded_model = urllib.parse.unquote(type_model.value)

        # Run the (potentially slow) prediction before opening a DB session.
        answer = cls._answer_for(decoded_question, decoded_mood, decoded_model)

        # Store the decoded values rather than the raw ones.
        question_dict.update(
            question=decoded_question,
            type_mood=decoded_mood,
            type_model=decoded_model,
            answer=answer,
            question_time=datetime.now(),
        )

        async with new_session() as session:
            question = QuestionOrm(**question_dict)
            session.add(question)
            # Flush explicitly so the new row is included in the per-user count
            # regardless of the session's autoflush setting and gets its id.
            await session.flush()
            # Capture the id before commit: reading attributes after commit can
            # trigger a lazy refresh, which is not allowed in async sessions.
            question_id = question.id
            await cls._trim_user_history(session, data.email_user)
            await session.commit()
        return question_id, answer

    @staticmethod
    async def _fetch_questions(query) -> list[SQuestion]:
        """Execute ``query`` in a new session and convert the rows to ``SQuestion``."""
        async with new_session() as session:
            result = await session.execute(query)
            return [SQuestion.model_validate(row) for row in result.scalars().all()]

    @classmethod
    async def find_all(cls) -> list[SQuestion]:
        """Return every stored question."""
        return await cls._fetch_questions(select(QuestionOrm))

    @classmethod
    async def find_by_email(cls, email_user: str) -> list[SQuestion]:
        """Return the stored questions whose ``email_user`` equals ``email_user``."""
        query = select(QuestionOrm).where(QuestionOrm.email_user == email_user)
        return await cls._fetch_questions(query)

    @staticmethod
    async def get_class_statistics() -> dict:
        """Count stored questions per mood and answer.

        Returns:
            A nested dict ``{type_mood: {answer: count}}``.
        """
        async with new_session() as session:
            query = select(
                QuestionOrm.type_mood, QuestionOrm.answer, func.count(QuestionOrm.id)
            ).group_by(QuestionOrm.type_mood, QuestionOrm.answer)
            result = await session.execute(query)
            statistics: dict = {}
            # GROUP BY already yields one row per (mood, answer) pair.
            for mood, answer, count in result.all():
                statistics.setdefault(mood, {})[answer] = count
            return statistics