diff --git a/controllers/controller.py b/controllers/controller.py index e69de29..f52e6a6 100644 --- a/controllers/controller.py +++ b/controllers/controller.py @@ -0,0 +1,20 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from typing import List +from schemas.schemas import LaptopCreate, LaptopResponse, PredictPriceResponse +from services.service import LaptopService +import os + +router = APIRouter() + +# Инициализация сервиса +MODEL_PATH = os.getenv("MODEL_PATH", "laptop_price_model.pkl") +FEATURE_COLUMNS_PATH = os.getenv("FEATURE_COLUMNS_PATH", "feature_columns.pkl") +laptop_service = LaptopService(model_path=MODEL_PATH, feature_columns_path=FEATURE_COLUMNS_PATH) + +@router.post("/predict_price/", response_model=PredictPriceResponse) +def predict_price(data: LaptopCreate): + try: + return laptop_service.predict_price(data.dict()) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) \ No newline at end of file diff --git a/main.py b/main.py index fa4cae6..a980e41 100644 --- a/main.py +++ b/main.py @@ -1,16 +1,7 @@ -from fastapi import FastAPI, Depends, HTTPException -from sqlalchemy.orm import Session -from typing import List -from models.models import Laptop -from schemas import LaptopCreate, LaptopResponse -from database import SessionLocal, engine, Base -from pydantic import BaseModel -import joblib -import pandas as pd -import numpy as np - -# Загрузка модели и списка признаков -model = joblib.load('laptop_price_model.pkl') -feature_columns = joblib.load('feature_columns.pkl') +from fastapi import FastAPI +from controllers import controller app = FastAPI() + +# Подключение маршрутов +app.include_router(controller.router) diff --git a/models/models.py b/models/models.py index 7a507a1..1845e06 100644 --- a/models/models.py +++ b/models/models.py @@ -5,9 +5,8 @@ class Laptop(Base): __tablename__ = "laptops" id = Column(Integer, primary_key=True, index=True) - title = Column(String, index=True) - price = Column(Float) - processor = Column(String) + processor = Column(String, index=True) ram = Column(Integer) + os = Column(String, index=True) ssd = Column(Integer) display = Column(Float) diff --git a/services/modelBuilder.py b/services/modelBuilder.py index a0488a1..60c56bd 100644 --- a/services/modelBuilder.py +++ b/services/modelBuilder.py @@ -7,7 +7,7 @@ import joblib import re # Шаг 1: Загрузка данных -df = pd.read_csv('laptops.csv') +df = pd.read_csv('../laptops.csv') # Шаг 2: Проверка и очистка имен столбцов print("Имена столбцов до очистки:") @@ -134,10 +134,10 @@ for name, mdl in models.items(): print(f"{name} - MAE: {mae}, RMSE: {rmse}, R²: {r2}") # Шаг 12: Сохранение модели -joblib.dump(model, 'laptop_price_model.pkl') +joblib.dump(model, '../laptop_price_model.pkl') print("\nМодель сохранена как 'laptop_price_model.pkl'.") # Дополнительно: Сохранение колонок, полученных после One-Hot Encoding, для использования в бэкенде feature_columns = X.columns.tolist() -joblib.dump(feature_columns, 'feature_columns.pkl') +joblib.dump(feature_columns, '../feature_columns.pkl') print("Сохранены названия признаков в 'feature_columns.pkl'.") diff --git a/services/service.py b/services/service.py index e69de29..c71f023 100644 --- a/services/service.py +++ b/services/service.py @@ -0,0 +1,40 @@ +import pandas as pd +import joblib +from typing import List, Dict +from schemas.schemas import LaptopCreate, LaptopResponse, PredictPriceResponse + +class LaptopService: + def __init__(self, model_path: str, feature_columns_path: str): + try: + self.model = joblib.load(model_path) + except FileNotFoundError: + raise Exception(f"Model file not found at {model_path}") + except Exception as e: + raise Exception(f"Error loading model: {str(e)}") + + try: + self.feature_columns = joblib.load(feature_columns_path) + except FileNotFoundError: + raise Exception(f"Feature columns file not found at {feature_columns_path}") + except Exception as e: + raise Exception(f"Error loading feature columns: {str(e)}") + + def predict_price(self, data: Dict[str, any]) -> PredictPriceResponse: + # Преобразование данных в DataFrame + input_df = pd.DataFrame([data]) + + # Применение One-Hot Encoding к категориальным признакам + input_df = pd.get_dummies(input_df, columns=['processor', 'os'], drop_first=True) + + # Добавление отсутствующих признаков, если они есть + for col in self.feature_columns: + if col not in input_df.columns and col != 'price': + input_df[col] = 0 + + # Упорядочивание колонок согласно обучающей выборке + input_df = input_df[self.feature_columns] + + # Предсказание цены + predicted_price = self.model.predict(input_df)[0] + + return PredictPriceResponse(predicted_price=round(predicted_price, 2))