Compare commits
2 Commits
e83da8582d
...
28d22ea47e
Author | SHA1 | Date | |
---|---|---|---|
28d22ea47e | |||
6cbcb76867 |
@ -0,0 +1,20 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from schemas.schemas import LaptopCreate, LaptopResponse, PredictPriceResponse
|
||||
from services.service import LaptopService
|
||||
import os
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Инициализация сервиса
|
||||
MODEL_PATH = os.getenv("MODEL_PATH", "laptop_price_model.pkl")
|
||||
FEATURE_COLUMNS_PATH = os.getenv("FEATURE_COLUMNS_PATH", "feature_columns.pkl")
|
||||
laptop_service = LaptopService(model_path=MODEL_PATH, feature_columns_path=FEATURE_COLUMNS_PATH)
|
||||
|
||||
@router.post("/predict_price/", response_model=PredictPriceResponse)
|
||||
def predict_price(data: LaptopCreate):
|
||||
try:
|
||||
return laptop_service.predict_price(data.dict())
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
19
main.py
19
main.py
@ -1,16 +1,7 @@
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from models.models import Laptop
|
||||
from schemas import LaptopCreate, LaptopResponse
|
||||
from database import SessionLocal, engine, Base
|
||||
from pydantic import BaseModel
|
||||
import joblib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# Загрузка модели и списка признаков
|
||||
model = joblib.load('laptop_price_model.pkl')
|
||||
feature_columns = joblib.load('feature_columns.pkl')
|
||||
from fastapi import FastAPI
|
||||
from controllers import controller
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Подключение маршрутов
|
||||
app.include_router(controller.router)
|
||||
|
@ -5,9 +5,8 @@ class Laptop(Base):
|
||||
__tablename__ = "laptops"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
title = Column(String, index=True)
|
||||
price = Column(Float)
|
||||
processor = Column(String)
|
||||
processor = Column(String, index=True)
|
||||
ram = Column(Integer)
|
||||
os = Column(String, index=True)
|
||||
ssd = Column(Integer)
|
||||
display = Column(Float)
|
||||
|
@ -7,7 +7,7 @@ import joblib
|
||||
import re
|
||||
|
||||
# Шаг 1: Загрузка данных
|
||||
df = pd.read_csv('laptops.csv')
|
||||
df = pd.read_csv('../laptops.csv')
|
||||
|
||||
# Шаг 2: Проверка и очистка имен столбцов
|
||||
print("Имена столбцов до очистки:")
|
||||
@ -134,10 +134,10 @@ for name, mdl in models.items():
|
||||
print(f"{name} - MAE: {mae}, RMSE: {rmse}, R²: {r2}")
|
||||
|
||||
# Шаг 12: Сохранение модели
|
||||
joblib.dump(model, 'laptop_price_model.pkl')
|
||||
joblib.dump(model, '../laptop_price_model.pkl')
|
||||
print("\nМодель сохранена как 'laptop_price_model.pkl'.")
|
||||
|
||||
# Дополнительно: Сохранение колонок, полученных после One-Hot Encoding, для использования в бэкенде
|
||||
feature_columns = X.columns.tolist()
|
||||
joblib.dump(feature_columns, 'feature_columns.pkl')
|
||||
joblib.dump(feature_columns, '../feature_columns.pkl')
|
||||
print("Сохранены названия признаков в 'feature_columns.pkl'.")
|
||||
|
@ -0,0 +1,40 @@
|
||||
import pandas as pd
|
||||
import joblib
|
||||
from typing import List, Dict
|
||||
from schemas.schemas import LaptopCreate, LaptopResponse, PredictPriceResponse
|
||||
|
||||
class LaptopService:
|
||||
def __init__(self, model_path: str, feature_columns_path: str):
|
||||
try:
|
||||
self.model = joblib.load(model_path)
|
||||
except FileNotFoundError:
|
||||
raise Exception(f"Model file not found at {model_path}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Error loading model: {str(e)}")
|
||||
|
||||
try:
|
||||
self.feature_columns = joblib.load(feature_columns_path)
|
||||
except FileNotFoundError:
|
||||
raise Exception(f"Feature columns file not found at {feature_columns_path}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Error loading feature columns: {str(e)}")
|
||||
|
||||
def predict_price(self, data: Dict[str, any]) -> PredictPriceResponse:
|
||||
# Преобразование данных в DataFrame
|
||||
input_df = pd.DataFrame([data])
|
||||
|
||||
# Применение One-Hot Encoding к категориальным признакам
|
||||
input_df = pd.get_dummies(input_df, columns=['processor', 'os'], drop_first=True)
|
||||
|
||||
# Добавление отсутствующих признаков, если они есть
|
||||
for col in self.feature_columns:
|
||||
if col not in input_df.columns and col != 'price':
|
||||
input_df[col] = 0
|
||||
|
||||
# Упорядочивание колонок согласно обучающей выборке
|
||||
input_df = input_df[self.feature_columns]
|
||||
|
||||
# Предсказание цены
|
||||
predicted_price = self.model.predict(input_df)[0]
|
||||
|
||||
return PredictPriceResponse(predicted_price=round(predicted_price, 2))
|
Loading…
Reference in New Issue
Block a user