from sqlalchemy import select, delete, func
from sqlalchemy.orm import joinedload
from datetime import datetime
import urllib.parse

from database import new_session, QuestionOrm
from enums import TypeMood, TypeModel
from schemas import SQuestionAdd, SQuestion
from model import predict_answer


class QuestionRepository:
    """Data-access helpers for ``QuestionOrm`` rows (questions and their predicted answers)."""

    # Predictions whose highest score (``max(prediction)``) is below this value
    # are stored with the answer "Not Found" instead of the predicted class.
    MIN_CONFIDENCE = 0.2
    # Maximum number of stored questions per ``email_user``; rows beyond the
    # newest MAX_QUESTIONS_PER_USER are deleted whenever a question is added.
    MAX_QUESTIONS_PER_USER = 100

    @classmethod
    async def add_one(
        cls, data: SQuestionAdd, type_mood: TypeMood, type_model: TypeModel
    ) -> tuple[int, str]:
        """Predict an answer for ``data.question``, persist it, and trim the user's history.

        The question text and the ``type_mood`` / ``type_model`` enum values are
        URL-decoded before being passed to ``predict_answer`` and stored.

        Args:
            data: Incoming question payload; ``model_dump()`` supplies the base
                column values (it must contain ``question``; ``email_user`` is
                read as an attribute).
            type_mood: Mood selector; its ``.value`` is URL-decoded and stored.
            type_model: Model selector; its ``.value`` is URL-decoded and stored.

        Returns:
            ``(id, answer)`` of the newly inserted row, where ``answer`` is the
            predicted class, or ``"Not Found"`` when ``max(prediction)`` is
            below ``MIN_CONFIDENCE``.
        """
        question_dict = data.model_dump()

        # Decode URL-encoded parameters.
        decoded_question = urllib.parse.unquote(question_dict["question"])
        decoded_mood = urllib.parse.unquote(type_mood.value)
        decoded_model = urllib.parse.unquote(type_model.value)

        # Predict the answer with the model (done before opening the session so
        # no database session is held while the model runs).
        predicted_class, prediction = predict_answer(
            decoded_question, decoded_mood, decoded_model
        )

        # Fall back to "Not Found" when the model is not confident enough.
        answer = "Not Found" if max(prediction) < cls.MIN_CONFIDENCE else predicted_class

        # Store the decoded values rather than the raw URL-encoded ones.
        question_dict["question"] = decoded_question
        question_dict["type_mood"] = decoded_mood
        question_dict["type_model"] = decoded_model
        question_dict["answer"] = answer
        question_dict["question_time"] = datetime.now()

        async with new_session() as session:
            question = QuestionOrm(**question_dict)
            session.add(question)
            # Flush explicitly so the primary key is assigned and the trimming
            # query below sees the new row regardless of autoflush settings.
            await session.flush()
            # Read the id before commit: if the session expires objects on
            # commit, touching attributes afterwards would require implicit IO,
            # which AsyncSession does not allow.
            question_id = question.id

            await cls._trim_user_history(session, data.email_user)
            await session.commit()

        return question_id, answer

    @classmethod
    async def _trim_user_history(cls, session, email_user) -> None:
        """Delete ``email_user``'s questions beyond the newest ``MAX_QUESTIONS_PER_USER``."""
        # Newest first; ``id`` breaks ties between equal timestamps (assumes ids
        # increase with insertion order — TODO confirm) so the just-inserted row
        # is kept.
        excess_ids_query = (
            select(QuestionOrm.id)
            .where(QuestionOrm.email_user == email_user)
            .order_by(QuestionOrm.question_time.desc(), QuestionOrm.id.desc())
            .offset(cls.MAX_QUESTIONS_PER_USER)
        )
        excess_ids = (await session.execute(excess_ids_query)).scalars().all()
        if excess_ids:
            await session.execute(
                delete(QuestionOrm).where(QuestionOrm.id.in_(excess_ids))
            )

    @classmethod
    async def find_all(cls) -> list[SQuestion]:
        """Return every stored question converted to ``SQuestion`` schemas."""
        async with new_session() as session:
            query = select(QuestionOrm)
            result = await session.execute(query)
            question_models = result.scalars().all()
            return [SQuestion.model_validate(question_model) for question_model in question_models]

    @classmethod
    async def find_by_email(cls, email_user: str) -> list[SQuestion]:
        """Return the questions whose ``email_user`` equals *email_user* as ``SQuestion`` schemas."""
        async with new_session() as session:
            query = select(QuestionOrm).where(QuestionOrm.email_user == email_user)
            result = await session.execute(query)
            question_models = result.scalars().all()
            return [SQuestion.model_validate(question_model) for question_model in question_models]

    @classmethod
    async def get_class_statistics(cls) -> dict:
        """Count stored questions per mood and answer.

        Returns:
            A nested dict ``{type_mood: {answer: count}}`` built from a
            ``GROUP BY type_mood, answer`` query.
        """
        async with new_session() as session:
            query = (
                select(QuestionOrm.type_mood, QuestionOrm.answer, func.count(QuestionOrm.id))
                .group_by(QuestionOrm.type_mood, QuestionOrm.answer)
            )
            result = await session.execute(query)

            statistics = {}
            for mood, answer, count in result.fetchall():
                per_mood = statistics.setdefault(mood, {})
                # GROUP BY makes (mood, answer) unique, but accumulate anyway
                # so duplicate keys could never silently overwrite a count.
                per_mood[answer] = per_mood.get(answer, 0) + count
            return statistics