60 lines
2.5 KiB
Python
60 lines
2.5 KiB
Python
from sqlalchemy import select, delete
|
||
from sqlalchemy.orm import joinedload
|
||
from datetime import datetime
|
||
|
||
from database import new_session, QuestionOrm
|
||
from enums import TypeMood, TypeModel
|
||
from schemas import SQuestionAdd, SQuestion
|
||
from model import predict_answer
|
||
|
||
class QuestionRepository:
    """Data-access helpers for stored questions and their predicted answers."""

    # When the highest score returned by the model is below this value, the
    # stored answer is "Not Found" instead of the predicted class.
    MIN_CONFIDENCE = 0.2

    # Maximum number of questions kept per user; older ones are deleted when
    # a new question is added.
    MAX_QUESTIONS_PER_USER = 10

    @classmethod
    async def add_one(cls, data: SQuestionAdd, type_mood: TypeMood, type_model: TypeModel) -> int:
        """Predict an answer for a question, store it, and return the new row id.

        Args:
            data: Incoming question payload. Its ``question`` field is sent to
                the model and ``email_user`` identifies the owner.
            type_mood: Forwarded unchanged to ``predict_answer``.
            type_model: Forwarded unchanged to ``predict_answer``.

        Returns:
            The primary key of the inserted ``QuestionOrm`` row.

        After inserting, only the newest ``MAX_QUESTIONS_PER_USER`` questions
        of the same user are kept; any older ones are deleted in the same
        transaction.
        """
        question_dict = data.model_dump()

        # Predict the answer with the model.
        # NOTE(review): predict_answer is called synchronously inside a
        # coroutine; if inference is CPU-heavy it blocks the event loop —
        # consider offloading it (e.g. asyncio.to_thread).
        predicted_class, prediction = predict_answer(question_dict["question"], type_mood, type_model)

        # Reject low-confidence predictions.
        # assumes `prediction` is a flat sequence of class scores — TODO confirm
        answer = "Not Found" if max(prediction) < cls.MIN_CONFIDENCE else predicted_class

        question_dict["answer"] = answer
        # NOTE(review): naive local time; confirm this matches the column type.
        question_dict["question_time"] = datetime.now()

        async with new_session() as session:
            question = QuestionOrm(**question_dict)
            session.add(question)
            # Flush explicitly so the row receives its id and is visible to
            # the pruning query, regardless of the session's autoflush setting.
            await session.flush()
            # Capture the id before commit: commit may expire attributes, and
            # re-loading them lazily is not allowed on an async session.
            question_id = question.id

            await cls._prune_old_questions(session, data.email_user)

            await session.commit()

        return question_id

    @classmethod
    async def _prune_old_questions(cls, session, email_user: str) -> None:
        """Delete every question of *email_user* beyond the newest MAX_QUESTIONS_PER_USER."""
        # Newest first (id breaks ties on equal timestamps); everything past
        # the first MAX_QUESTIONS_PER_USER rows is surplus.
        query = (
            select(QuestionOrm)
            .where(QuestionOrm.email_user == email_user)
            .order_by(QuestionOrm.question_time.desc(), QuestionOrm.id.desc())
            .offset(cls.MAX_QUESTIONS_PER_USER)
        )
        result = await session.execute(query)
        for stale_question in result.scalars().all():
            await session.delete(stale_question)

    @staticmethod
    async def _fetch_schemas(query) -> list[SQuestion]:
        """Run *query* in a fresh session and convert the ORM rows to ``SQuestion``."""
        async with new_session() as session:
            result = await session.execute(query)
            return [SQuestion.model_validate(question_model) for question_model in result.scalars().all()]

    @classmethod
    async def find_all(cls) -> list[SQuestion]:
        """Return every stored question as a ``SQuestion`` schema."""
        return await cls._fetch_schemas(select(QuestionOrm))

    @classmethod
    async def find_by_email(cls, email_user: str) -> list[SQuestion]:
        """Return all questions whose ``email_user`` equals *email_user*."""
        return await cls._fetch_schemas(select(QuestionOrm).where(QuestionOrm.email_user == email_user))