price-builder-backend/main.py

80 lines
2.5 KiB
Python
Raw Permalink Normal View History

from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from models import Laptop
from schemas import LaptopCreate, LaptopResponse
from database import SessionLocal, engine, Base
from pydantic import BaseModel
import joblib
import pandas as pd
import numpy as np
# Загрузка модели и списка признаков
model = joblib.load('laptop_price_model.pkl')
feature_columns = joblib.load('feature_columns.pkl')
app = FastAPI()
# Определение Pydantic модели для входных данных
class LaptopData(BaseModel):
processor: str
ram: int
os: str
ssd: int
display: float
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@app.post("/laptops/", response_model=LaptopResponse)
def create_laptop(laptop: LaptopCreate, db: Session = Depends(get_db)):
db_laptop = Laptop(**laptop.dict())
db.add(db_laptop)
db.commit()
db.refresh(db_laptop)
return db_laptop
@app.get("/laptops/", response_model=List[LaptopResponse])
def read_laptops(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)):
laptops = db.query(Laptop).offset(skip).limit(limit).all()
return laptops
@app.get("/laptops/{laptop_id}", response_model=LaptopResponse)
def read_laptop(laptop_id: int, db: Session = Depends(get_db)):
laptop = db.query(Laptop).filter(Laptop.id == laptop_id).first()
if laptop is None:
raise HTTPException(status_code=404, detail="Laptop not found")
return laptop
# Эндпоинт для предсказания цены
@app.post("/predict_price/")
def predict_price(data: LaptopData):
input_data = data.dict()
# Преобразование данных в DataFrame
input_df = pd.DataFrame([input_data])
# Применение One-Hot Encoding к категориальным признакам
input_df = pd.get_dummies(input_df, columns=['processor', 'os'], drop_first=True)
# Добавление отсутствующих признаков, если они есть
for col in feature_columns:
if col not in input_df.columns and col != 'price':
input_df[col] = 0
# Упорядочивание колонок согласно обучающей выборке
input_df = input_df[feature_columns]
# Предсказание цены
predicted_price = model.predict(input_df)[0]
return {"predicted_price": round(predicted_price, 2)}