2024-06-08 01:22:35 +04:00
|
|
|
|
from sqlalchemy import select, delete, func
|
2024-06-02 15:51:14 +04:00
|
|
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
|
from datetime import datetime
|
2024-06-07 16:51:16 +04:00
|
|
|
|
import urllib.parse
|
2024-05-25 16:11:52 +04:00
|
|
|
|
|
2024-05-25 16:51:34 +04:00
|
|
|
|
from database import new_session, QuestionOrm
|
2024-06-02 16:38:08 +04:00
|
|
|
|
from enums import TypeMood, TypeModel
|
2024-05-25 16:51:34 +04:00
|
|
|
|
from schemas import SQuestionAdd, SQuestion
|
2024-05-25 17:38:24 +04:00
|
|
|
|
from model import predict_answer
|
|
|
|
|
|
2024-05-25 16:51:34 +04:00
|
|
|
|
class QuestionRepository:
    """Data-access helpers for ``QuestionOrm`` rows.

    Every public method opens its own session via ``new_session()``.
    """

    # Maximum number of stored questions kept per ``email_user``; the oldest
    # ones (by ``question_time``) are deleted by ``add_one`` beyond this.
    MAX_QUESTIONS_PER_USER = 100
    # When the largest value in the model's ``prediction`` is below this
    # threshold, the stored answer is "Not Found" instead of the predicted class.
    MIN_CONFIDENCE = 0.2

    @classmethod
    async def add_one(cls, data: SQuestionAdd, type_mood: TypeMood, type_model: TypeModel) -> tuple[int, str]:
        """Predict an answer for ``data``, store it, and prune the user's history.

        The question text and the mood/model enum values are URL-decoded
        before being passed to ``predict_answer`` and before being stored.

        Args:
            data: Incoming question payload (must provide ``question`` and
                ``email_user``).
            type_mood: Mood selector; its URL-decoded ``value`` is stored and
                forwarded to the model.
            type_model: Model selector; its URL-decoded ``value`` is stored and
                forwarded to the model.

        Returns:
            ``(id, answer)`` of the newly stored question.
        """
        question_dict = data.model_dump()

        # Decode URL-encoded parameters.
        decoded_question = urllib.parse.unquote(question_dict["question"])
        decoded_mood = urllib.parse.unquote(type_mood.value)
        decoded_model = urllib.parse.unquote(type_model.value)

        # Predict the answer with the model. This needs no database access, so it
        # runs before a session is opened.
        # NOTE(review): predict_answer is a synchronous call inside a coroutine and
        # blocks the event loop while it runs; consider asyncio.to_thread if the
        # model is thread-safe — TODO confirm.
        predicted_class, prediction = predict_answer(decoded_question, decoded_mood, decoded_model)

        # Reject low-confidence predictions.
        # NOTE(review): assumes ``prediction`` is a flat sequence of class scores —
        # TODO confirm against predict_answer.
        answer = "Not Found" if max(prediction) < cls.MIN_CONFIDENCE else predicted_class

        # Store the decoded values rather than the raw URL-encoded input.
        question_dict.update(
            question=decoded_question,
            type_mood=decoded_mood,
            type_model=decoded_model,
            answer=answer,
            question_time=datetime.now(),
        )

        async with new_session() as session:
            question = QuestionOrm(**question_dict)
            session.add(question)
            # Flush explicitly so the row gets its primary key and is included in
            # the per-user count below, independent of the autoflush setting.
            await session.flush()
            # Read the id before commit so it does not depend on expire_on_commit
            # (reloading an expired attribute in an async session requires I/O).
            question_id = question.id

            await cls._prune_old_questions(session, data.email_user)

            await session.commit()

        return question_id, answer

    @classmethod
    async def _prune_old_questions(cls, session, email_user: str) -> None:
        """Delete ``email_user``'s oldest questions beyond ``MAX_QUESTIONS_PER_USER``.

        Counting happens in SQL, and only the surplus rows are loaded. Every
        surplus row is removed, not just one, so the cap holds even if a user is
        already over it.
        """
        count_query = (
            select(func.count())
            .select_from(QuestionOrm)
            .where(QuestionOrm.email_user == email_user)
        )
        total = (await session.execute(count_query)).scalar_one()
        excess = total - cls.MAX_QUESTIONS_PER_USER
        if excess <= 0:
            return

        # Oldest first; id breaks ties between identical timestamps.
        oldest_query = (
            select(QuestionOrm)
            .where(QuestionOrm.email_user == email_user)
            .order_by(QuestionOrm.question_time, QuestionOrm.id)
            .limit(excess)
        )
        oldest_questions = (await session.execute(oldest_query)).scalars().all()
        for old_question in oldest_questions:
            await session.delete(old_question)

    @classmethod
    async def find_all(cls) -> list[SQuestion]:
        """Return every stored question converted to ``SQuestion``."""
        return await cls._fetch_schemas(select(QuestionOrm))

    @classmethod
    async def find_by_email(cls, email_user: str) -> list[SQuestion]:
        """Return the questions whose ``email_user`` equals ``email_user``."""
        query = select(QuestionOrm).where(QuestionOrm.email_user == email_user)
        return await cls._fetch_schemas(query)

    @staticmethod
    async def _fetch_schemas(query) -> list[SQuestion]:
        """Execute a ``QuestionOrm`` select and validate each row into ``SQuestion``.

        Validation happens while the session is still open, so any attribute
        access on the ORM rows can still reach the database.
        """
        async with new_session() as session:
            result = await session.execute(query)
            return [SQuestion.model_validate(row) for row in result.scalars().all()]

    @classmethod
    async def get_class_statistics(cls) -> dict:
        """Return a mapping of each stored ``answer`` value to its row count."""
        async with new_session() as session:
            query = (
                select(QuestionOrm.answer, func.count(QuestionOrm.id))
                .group_by(QuestionOrm.answer)
            )
            result = await session.execute(query)
            return {answer: count for answer, count in result.all()}