Compare commits

..

2 Commits

Author SHA1 Message Date
28d22ea47e merge comp 2024-10-13 16:56:01 +04:00
6cbcb76867 Fixed. 2024-10-13 16:53:07 +04:00
5 changed files with 70 additions and 20 deletions

View File

@ -0,0 +1,20 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from schemas.schemas import LaptopCreate, LaptopResponse, PredictPriceResponse
from services.service import LaptopService
import os
router = APIRouter()
# Инициализация сервиса
MODEL_PATH = os.getenv("MODEL_PATH", "laptop_price_model.pkl")
FEATURE_COLUMNS_PATH = os.getenv("FEATURE_COLUMNS_PATH", "feature_columns.pkl")
laptop_service = LaptopService(model_path=MODEL_PATH, feature_columns_path=FEATURE_COLUMNS_PATH)
@router.post("/predict_price/", response_model=PredictPriceResponse)
def predict_price(data: LaptopCreate):
try:
return laptop_service.predict_price(data.dict())
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

19
main.py
View File

@ -1,16 +1,7 @@
from fastapi import FastAPI, Depends, HTTPException from fastapi import FastAPI
from sqlalchemy.orm import Session from controllers import controller
from typing import List
from models.models import Laptop
from schemas import LaptopCreate, LaptopResponse
from database import SessionLocal, engine, Base
from pydantic import BaseModel
import joblib
import pandas as pd
import numpy as np
# Загрузка модели и списка признаков
model = joblib.load('laptop_price_model.pkl')
feature_columns = joblib.load('feature_columns.pkl')
app = FastAPI() app = FastAPI()
# Подключение маршрутов
app.include_router(controller.router)

View File

@ -5,9 +5,8 @@ class Laptop(Base):
__tablename__ = "laptops" __tablename__ = "laptops"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
title = Column(String, index=True) processor = Column(String, index=True)
price = Column(Float)
processor = Column(String)
ram = Column(Integer) ram = Column(Integer)
os = Column(String, index=True)
ssd = Column(Integer) ssd = Column(Integer)
display = Column(Float) display = Column(Float)

View File

@ -7,7 +7,7 @@ import joblib
import re import re
# Шаг 1: Загрузка данных # Шаг 1: Загрузка данных
df = pd.read_csv('laptops.csv') df = pd.read_csv('../laptops.csv')
# Шаг 2: Проверка и очистка имен столбцов # Шаг 2: Проверка и очистка имен столбцов
print("Имена столбцов до очистки:") print("Имена столбцов до очистки:")
@ -134,10 +134,10 @@ for name, mdl in models.items():
print(f"{name} - MAE: {mae}, RMSE: {rmse}, R²: {r2}") print(f"{name} - MAE: {mae}, RMSE: {rmse}, R²: {r2}")
# Шаг 12: Сохранение модели # Шаг 12: Сохранение модели
joblib.dump(model, 'laptop_price_model.pkl') joblib.dump(model, '../laptop_price_model.pkl')
print("\nМодель сохранена как 'laptop_price_model.pkl'.") print("\nМодель сохранена как 'laptop_price_model.pkl'.")
# Дополнительно: Сохранение колонок, полученных после One-Hot Encoding, для использования в бэкенде # Дополнительно: Сохранение колонок, полученных после One-Hot Encoding, для использования в бэкенде
feature_columns = X.columns.tolist() feature_columns = X.columns.tolist()
joblib.dump(feature_columns, 'feature_columns.pkl') joblib.dump(feature_columns, '../feature_columns.pkl')
print("Сохранены названия признаков в 'feature_columns.pkl'.") print("Сохранены названия признаков в 'feature_columns.pkl'.")

View File

@ -0,0 +1,40 @@
import pandas as pd
import joblib
from typing import List, Dict
from schemas.schemas import LaptopCreate, LaptopResponse, PredictPriceResponse
class LaptopService:
def __init__(self, model_path: str, feature_columns_path: str):
try:
self.model = joblib.load(model_path)
except FileNotFoundError:
raise Exception(f"Model file not found at {model_path}")
except Exception as e:
raise Exception(f"Error loading model: {str(e)}")
try:
self.feature_columns = joblib.load(feature_columns_path)
except FileNotFoundError:
raise Exception(f"Feature columns file not found at {feature_columns_path}")
except Exception as e:
raise Exception(f"Error loading feature columns: {str(e)}")
def predict_price(self, data: Dict[str, any]) -> PredictPriceResponse:
# Преобразование данных в DataFrame
input_df = pd.DataFrame([data])
# Применение One-Hot Encoding к категориальным признакам
input_df = pd.get_dummies(input_df, columns=['processor', 'os'], drop_first=True)
# Добавление отсутствующих признаков, если они есть
for col in self.feature_columns:
if col not in input_df.columns and col != 'price':
input_df[col] = 0
# Упорядочивание колонок согласно обучающей выборке
input_df = input_df[self.feature_columns]
# Предсказание цены
predicted_price = self.model.predict(input_df)[0]
return PredictPriceResponse(predicted_price=round(predicted_price, 2))