# PIbd-32_Kashin_M.I_API_Cour.../repository.py
# Repository-viewer header (not part of the module): 100 lines, 4.3 KiB, Python.
# Last change: 2024-06-08 01:22:35 +04:00
from sqlalchemy import select, delete, func
from sqlalchemy.orm import joinedload
from datetime import datetime
import urllib.parse
from database import new_session, QuestionOrm
from enums import TypeMood, TypeModel
from gigachat import giga_token, get_chat_completion
from schemas import SQuestionAdd, SQuestion
from model import predict_answer
class QuestionRepository:
    """Data-access helpers for ``QuestionOrm`` rows.

    Every public method opens its own session through ``new_session()``.
    """

    # Predictions whose confidence is below this value are stored as "Not Found".
    MIN_CONFIDENCE = 0.2
    # When one email_user has more rows than this after an insert, that
    # user's oldest row is deleted.
    MAX_QUESTIONS_PER_USER = 100

    @classmethod
    def _pick_answer(cls, predicted_class, prediction):
        """Return *predicted_class*, or ``"Not Found"`` when confidence is too low.

        *prediction* is either a single ``float`` or an iterable of scores;
        for an iterable the highest score is the one compared against
        ``MIN_CONFIDENCE``.
        """
        confidence = prediction if isinstance(prediction, float) else max(prediction)
        return "Not Found" if confidence < cls.MIN_CONFIDENCE else predicted_class

    @classmethod
    async def _fetch_schemas(cls, query) -> list[SQuestion]:
        """Execute *query* and validate every returned ORM row as ``SQuestion``."""
        async with new_session() as session:
            result = await session.execute(query)
            return [SQuestion.model_validate(row) for row in result.scalars().all()]

    @classmethod
    async def add_one(
        cls, data: SQuestionAdd, type_mood: TypeMood, type_model: TypeModel
    ) -> tuple[int, str]:
        """Answer a question, store it, and return ``(id, answer)``.

        Args:
            data: Incoming question payload; its dumped fields become the
                ORM row, and ``data.email_user`` selects the rows that are
                counted against ``MAX_QUESTIONS_PER_USER``.
            type_mood: Mood selector; its ``.value`` is URL-decoded and stored.
            type_model: Model selector; the decoded value ``"GigaChad"``
                routes the question to ``get_chat_completion``, anything
                else to ``predict_answer``.

        Returns:
            The primary key of the stored row and the answer saved with it.
        """
        question_dict = data.model_dump()

        # Decode URL-encoded parameters.
        decoded_question = urllib.parse.unquote(question_dict["question"])
        decoded_mood = urllib.parse.unquote(type_mood.value)
        decoded_model = urllib.parse.unquote(type_model.value)

        # NOTE(review): both prediction calls are invoked synchronously from
        # this coroutine; if they are slow they stall the event loop -
        # consider offloading them to a worker thread.
        if decoded_model == "GigaChad":
            # The chat completion comes without a confidence score (the
            # previous code assigned it a fixed 100.00), so it is always
            # accepted as the answer.
            answer = get_chat_completion(giga_token, decoded_question)
        else:
            predicted_class, prediction = predict_answer(
                decoded_question, decoded_mood, decoded_model
            )
            answer = cls._pick_answer(predicted_class, prediction)

        # Store the decoded values together with the answer and a timestamp.
        question_dict["question"] = decoded_question
        question_dict["type_mood"] = decoded_mood
        question_dict["type_model"] = decoded_model
        question_dict["answer"] = answer
        question_dict["question_time"] = datetime.now()

        async with new_session() as session:
            question = QuestionOrm(**question_dict)
            session.add(question)

            # Check how many rows this email_user has.
            # NOTE(review): the pending row is only part of this result if
            # the session autoflushes before the SELECT (SQLAlchemy's
            # default) - confirm how new_session is configured.
            query = select(QuestionOrm).where(QuestionOrm.email_user == data.email_user)
            result = await session.execute(query)
            user_questions = result.scalars().all()

            if len(user_questions) > cls.MAX_QUESTIONS_PER_USER:
                # Delete the oldest row.
                oldest_question = min(user_questions, key=lambda q: q.question_time)
                await session.delete(oldest_question)

            await session.flush()
            # Read the values before commit(): with expire_on_commit enabled,
            # touching attributes afterwards would need an implicit refresh,
            # which an AsyncSession cannot perform.
            question_id = question.id
            stored_answer = question.answer
            await session.commit()

        return question_id, stored_answer

    @classmethod
    async def find_all(cls) -> list[SQuestion]:
        """Return every stored question as an ``SQuestion`` schema."""
        return await cls._fetch_schemas(select(QuestionOrm))

    @classmethod
    async def find_by_email(cls, email_user: str) -> list[SQuestion]:
        """Return the questions whose ``email_user`` equals *email_user*."""
        query = select(QuestionOrm).where(QuestionOrm.email_user == email_user)
        return await cls._fetch_schemas(query)

    @staticmethod
    async def get_class_statistics() -> dict:
        """Count stored rows per ``type_mood`` and ``answer``.

        Returns:
            A nested mapping ``{type_mood: {answer: count}}``.
        """
        async with new_session() as session:
            query = select(
                QuestionOrm.type_mood,
                QuestionOrm.answer,
                func.count(QuestionOrm.id),
            ).group_by(QuestionOrm.type_mood, QuestionOrm.answer)
            result = await session.execute(query)

            statistics = {}
            for mood, answer, count in result.fetchall():
                per_mood = statistics.setdefault(mood, {})
                # Accumulate rather than assign, so a repeated
                # (mood, answer) pair would still be summed.
                per_mood[answer] = per_mood.get(answer, 0) + count
            return statistics