Added quite a lot here: all of the model types, returning "Not Found" when the top probability is below 0.2, and endpoints that return all the class names for Andrey.
This commit is contained in:
parent
6ccd3343e9
commit
2787cf59ae
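To illustrate the change described in the message above, a request against the updated add-question endpoint might look like the sketch below. This is only an illustration: the /questions prefix, host, and port are assumptions (the router prefix is not shown in this diff), and the requests library is used purely for the example; the query-parameter names and enum values come from the changes themselves.

import requests

# Sketch: assumes the questions router is mounted at /questions on a local dev server.
# SQuestionAdd fields arrive as query parameters because the route uses
# Annotated[SQuestionAdd, Depends()]; type_mood and type_model are the new parameters.
response = requests.post(
    "http://127.0.0.1:8000/questions",
    params={
        "email_user": "user@example.com",
        "type_question": True,
        "type_mood": "Negative",   # a TypeMood value
        "type_model": "LSTM",      # a TypeModel value
        "question": "How do I reset my password?",
    },
)
print(response.json())  # {"ok": True, "question_id": ...}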
database.py
@@ -1,9 +1,10 @@
 from typing import Optional
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
-from sqlalchemy import Column, DateTime
+from sqlalchemy import Column, DateTime, Enum
 from sqlalchemy.sql import func
 from datetime import datetime
+from enums import TypeMood, TypeModel

 engine = create_async_engine("sqlite+aiosqlite:///questions.db")
 new_session = async_sessionmaker(engine, expire_on_commit=False)
@@ -16,7 +17,8 @@ class QuestionOrm(Model):

     id: Mapped[int] = mapped_column(primary_key=True)
     email_user: Mapped[str]
     type_question: Mapped[bool]
+    type_mood: Mapped[TypeMood] = mapped_column(Enum(TypeMood))
+    type_model: Mapped[TypeModel] = mapped_column(Enum(TypeModel))
     question: Mapped[str]
     answer: Mapped[Optional[str]]
     question_time: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())

enums.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+from enum import Enum
+
+class TypeMood(str, Enum):
+    POSITIVE = "Positive"
+    NEGATIVE = "Negative"
+
+class TypeModel(str, Enum):
+    LSTM = "LSTM"
+    GRU = "GRU"
+    CNN = "CNN"
model.py
@@ -4,8 +4,17 @@ import tensorflow as tf
 from keras.src.legacy.preprocessing.text import Tokenizer
 from keras.src.utils import pad_sequences

+from enums import TypeMood, TypeModel
+
 # Load the models
-model = tf.keras.models.load_model('.//neural_network/models/model/best_model_lstm_negative.keras')
+model_lstm_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_lstm_negative.keras')
+model_gru_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_gru_negative.keras')
+model_cnn_negative = tf.keras.models.load_model('.//neural_network/models/model/best_model_cnn_negative.keras')
+
+model_lstm_positive = tf.keras.models.load_model('.//neural_network/models/model/best_model_lstm_positive.keras')
+model_gru_positive = tf.keras.models.load_model('.//neural_network/models/model/best_model_gru_positive.keras')
+model_cnn_positive = tf.keras.models.load_model('.//neural_network/models/model/best_model_cnn_positive.keras')


 # Load the tokenizer
 with open('.//neural_network/tokenization/tokenizer_negative.pickle', 'rb') as handle:
@@ -18,21 +27,31 @@ with open('.//neural_network/classification/class_names_negative.txt', 'r', enco
 def preprocess_text(text: str):
     # Tokenize the text
     sequences = tokenizer.texts_to_sequences([text])
     # Pad the sequences to a fixed length
     padded_sequences = pad_sequences(sequences, maxlen=90)  # 90 is the sequence length used during training
     return padded_sequences

-def predict_answer(question: str) -> str:
+def predict_answer(question: str, type_mood: TypeMood, type_model: TypeModel) -> tuple[str, np.ndarray]:
+    if type_model == TypeModel.LSTM and type_mood == TypeMood.NEGATIVE:
+        model = model_lstm_negative
+    elif type_model == TypeModel.LSTM and type_mood == TypeMood.POSITIVE:
+        model = model_lstm_positive
+    elif type_model == TypeModel.GRU and type_mood == TypeMood.NEGATIVE:
+        model = model_gru_negative
+    elif type_model == TypeModel.GRU and type_mood == TypeMood.POSITIVE:
+        model = model_gru_positive
+    elif type_model == TypeModel.CNN and type_mood == TypeMood.NEGATIVE:
+        model = model_cnn_negative
+    elif type_model == TypeModel.CNN and type_mood == TypeMood.POSITIVE:
+        model = model_cnn_positive
+    else:
+        raise ValueError("Unsupported model type")
+
     # Preprocess the question
     print("Question:", question)
     input_data = preprocess_text(question)
     print("Preprocessed input:", input_data)
     # Predict
-    prediction = model.predict(input_data)
+    prediction = model.predict(input_data)[0]
     print("Prediction:", prediction)
     # Index of the class with the highest probability
-    predicted_index = np.argmax(prediction[0])
+    predicted_index = np.argmax(prediction)
     # Get the class name
     predicted_class = class_names[predicted_index]
     print("Predicted class:", predicted_class)
-    return predicted_class  # return the name of the predicted class
+    return predicted_class, prediction
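A side note on the selection logic above: the six-way if/elif chain could also be written as a lookup table keyed by (TypeModel, TypeMood). The sketch below is only an alternative shape, not what the commit does; the _MODELS table and select_model helper are hypothetical names, while the model variables and enums are the ones introduced in this diff.

# Sketch of an alternative to the if/elif chain (not part of this commit):
# a dispatch table keyed by (TypeModel, TypeMood).
_MODELS = {
    (TypeModel.LSTM, TypeMood.NEGATIVE): model_lstm_negative,
    (TypeModel.LSTM, TypeMood.POSITIVE): model_lstm_positive,
    (TypeModel.GRU, TypeMood.NEGATIVE): model_gru_negative,
    (TypeModel.GRU, TypeMood.POSITIVE): model_gru_positive,
    (TypeModel.CNN, TypeMood.NEGATIVE): model_cnn_negative,
    (TypeModel.CNN, TypeMood.POSITIVE): model_cnn_positive,
}

def select_model(type_model: TypeModel, type_mood: TypeMood):
    try:
        return _MODELS[(type_model, type_mood)]
    except KeyError:
        raise ValueError("Unsupported model type") from None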
repository.py
@@ -3,17 +3,24 @@ from sqlalchemy.orm import joinedload
 from datetime import datetime

 from database import new_session, QuestionOrm
+from enums import TypeMood, TypeModel
 from schemas import SQuestionAdd, SQuestion
 from model import predict_answer


 class QuestionRepository:
     @classmethod
-    async def add_one(cls, data: SQuestionAdd) -> int:
+    async def add_one(cls, data: SQuestionAdd, type_mood: TypeMood, type_model: TypeModel) -> int:
         async with new_session() as session:
             question_dict = data.model_dump()

             # Predict the answer with the model
-            answer = predict_answer(question_dict["question"])
+            predicted_class, prediction = predict_answer(question_dict["question"], type_mood, type_model)
+
+            # Check the class probabilities
+            if max(prediction) < 0.2:
+                answer = "Not Found"
+            else:
+                answer = predicted_class
+
             question_dict["answer"] = answer
             question_dict["question_time"] = datetime.now()
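For context on the 0.2 cut-off above: prediction is the softmax vector for one input, so max(prediction) is the probability of the most likely class, and a near-uniform distribution falls below the threshold. A quick sketch of how the check behaves (the class count, probabilities, and class name here are invented for illustration):

# Illustration only; the probabilities and class name are made up.
prediction = [0.18, 0.15, 0.17, 0.16, 0.19, 0.15]  # near-uniform softmax output
predicted_class = "billing"                         # hypothetical class name

answer = "Not Found" if max(prediction) < 0.2 else predicted_class
print(answer)  # "Not Found" — no class stands out, so the question is treated as unrecognised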
@@ -41,3 +48,12 @@ class QuestionRepository:
             question_models = result.scalars().all()
             question_schemas = [SQuestion.model_validate(question_model) for question_model in question_models]
             return question_schemas
+
+    @classmethod
+    async def find_by_email(cls, email_user: str) -> list[SQuestion]:
+        async with new_session() as session:
+            query = select(QuestionOrm).where(QuestionOrm.email_user == email_user)
+            result = await session.execute(query)
+            question_models = result.scalars().all()
+            question_schemas = [SQuestion.model_validate(question_model) for question_model in question_models]
+            return question_schemas
router.py
@@ -1,6 +1,8 @@
 from typing import Annotated
 from fastapi import APIRouter, Depends
+from typing import List

+from enums import TypeMood, TypeModel
 from repository import QuestionRepository
 from schemas import SQuestionAdd, SQuestion, SQuestionId

@@ -9,14 +11,34 @@ router = APIRouter(
     tags=["Questions"],
 )

+@router.get("/class_negative")
+async def get_class_names_negative() -> List[str]:
+    with open(".//neural_network/classification/class_names_negative.txt", "r", encoding="utf-8") as file:
+        class_names = [line.strip() for line in file.readlines()]
+    return class_names
+
+@router.get("/class_positive")
+async def get_class_names_positive() -> List[str]:
+    with open(".//neural_network/classification/class_names_positive.txt", "r", encoding="utf-8") as file:
+        class_names = [line.strip() for line in file.readlines()]
+    return class_names
+
 @router.post("")
 async def add_question(
     question: Annotated[SQuestionAdd, Depends()],
+    type_mood: TypeMood,    # new type_mood parameter
+    type_model: TypeModel,  # new type_model parameter
 ) -> SQuestionId:
-    question_id = await QuestionRepository.add_one(question)
+    question_id = await QuestionRepository.add_one(question, type_mood, type_model)  # pass type_mood and type_model through
     return {"ok": True, "question_id": question_id}

 @router.get("")
 async def get_questions() -> list[SQuestion]:
     questions = await QuestionRepository.find_all()
     return questions


+@router.get("/{email_user}")
+async def get_questions_by_email(email_user: str) -> list[SQuestion]:
+    questions = await QuestionRepository.find_by_email(email_user)
+    return questions
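The two class-list routes above are what the commit message calls returning all the class names; a call against them might look like the sketch below. The base URL and the /questions router prefix are assumptions not shown in this diff, and requests is used only for illustration.

import requests

# Sketch only: host, port, and router prefix are assumed, not taken from the diff.
negative_classes = requests.get("http://127.0.0.1:8000/questions/class_negative").json()
positive_classes = requests.get("http://127.0.0.1:8000/questions/class_positive").json()
print(negative_classes)  # a list of class-name strings read from class_names_negative.txt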
schemas.py
@@ -1,10 +1,12 @@
 from typing import Optional
 from pydantic import BaseModel, ConfigDict
 from datetime import datetime
+from enums import TypeMood, TypeModel

 class SQuestionAdd(BaseModel):
     email_user: str
     type_question: bool
+    type_mood: TypeMood
+    type_model: TypeModel
     question: str

 class SQuestion(SQuestionAdd):