from sqlalchemy import select, delete
from sqlalchemy.orm import joinedload
from datetime import datetime

from database import new_session, QuestionOrm
from enums import TypeMood, TypeModel
from schemas import SQuestionAdd, SQuestion
from model import predict_answer

class QuestionRepository:
    """Persistence operations for user questions and their model-predicted answers."""

    # Maximum number of stored questions kept per email_user; older ones are pruned.
    MAX_QUESTIONS_PER_USER = 10
    # When max(prediction) is below this value the stored answer is "Not Found".
    MIN_PREDICTION_CONFIDENCE = 0.2

    @classmethod
    async def add_one(cls, data: SQuestionAdd, type_mood: TypeMood, type_model: TypeModel) -> tuple[int, str]:
        """Predict an answer for a question, store it, and prune the user's history.

        Args:
            data: Incoming question payload. ``data.model_dump()["question"]`` is
                passed to the model; ``data.email_user`` selects whose history
                is pruned.
            type_mood: Forwarded unchanged to ``predict_answer``.
            type_model: Forwarded unchanged to ``predict_answer``.

        Returns:
            ``(id, answer)`` of the newly stored question. ``answer`` is the
            predicted class, or ``"Not Found"`` when ``max(prediction)`` is below
            ``MIN_PREDICTION_CONFIDENCE``.

        After the insert, only the newest ``MAX_QUESTIONS_PER_USER`` questions
        for ``data.email_user`` are kept; older ones are deleted in the same
        transaction.
        """
        question_dict = data.model_dump()

        # The prediction does not need the database, so run it before opening a session.
        predicted_class, prediction = predict_answer(question_dict["question"], type_mood, type_model)
        answer = "Not Found" if max(prediction) < cls.MIN_PREDICTION_CONFIDENCE else predicted_class

        question_dict["answer"] = answer
        question_dict["question_time"] = datetime.now()

        async with new_session() as session:
            question = QuestionOrm(**question_dict)
            session.add(question)
            # Flush explicitly so the new row is assigned an id and is seen by the
            # pruning query below regardless of the session's autoflush setting.
            await session.flush()

            # Select every row beyond the newest MAX_QUESTIONS_PER_USER for this user.
            # id breaks ties between equal timestamps (assumes ids grow with insertion
            # order — TODO confirm for the QuestionOrm primary key).
            excess_query = (
                select(QuestionOrm)
                .where(QuestionOrm.email_user == data.email_user)
                .order_by(QuestionOrm.question_time.desc(), QuestionOrm.id.desc())
                .offset(cls.MAX_QUESTIONS_PER_USER)
            )
            excess_questions = (await session.execute(excess_query)).scalars().all()
            for old_question in excess_questions:
                await session.delete(old_question)

            # Read the values before committing: if the session expires objects on
            # commit, touching attributes afterwards would require an implicit
            # lazy load, which an AsyncSession does not allow.
            question_id, question_answer = question.id, question.answer
            await session.commit()
            return question_id, question_answer

    @staticmethod
    def _to_schemas(question_models) -> list[SQuestion]:
        """Convert loaded ``QuestionOrm`` rows into ``SQuestion`` schemas."""
        return [SQuestion.model_validate(question_model) for question_model in question_models]

    @classmethod
    async def find_all(cls) -> list[SQuestion]:
        """Return every stored question as an ``SQuestion``."""
        async with new_session() as session:
            result = await session.execute(select(QuestionOrm))
            # Validate while the session is still open so attribute access cannot
            # hit a closed session.
            return cls._to_schemas(result.scalars().all())

    @classmethod
    async def find_by_email(cls, email_user: str) -> list[SQuestion]:
        """Return the stored questions whose ``email_user`` equals ``email_user``."""
        async with new_session() as session:
            query = select(QuestionOrm).where(QuestionOrm.email_user == email_user)
            result = await session.execute(query)
            return cls._to_schemas(result.scalars().all())