80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
from fastapi import FastAPI, Depends, HTTPException
|
|
from sqlalchemy.orm import Session
|
|
from typing import List
|
|
from models import Laptop
|
|
from schemas import LaptopCreate, LaptopResponse
|
|
from database import SessionLocal, engine, Base
|
|
from pydantic import BaseModel
|
|
import joblib
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
# Загрузка модели и списка признаков
|
|
model = joblib.load('laptop_price_model.pkl')
|
|
feature_columns = joblib.load('feature_columns.pkl')
|
|
|
|
app = FastAPI()
|
|
|
|
# Определение Pydantic модели для входных данных
|
|
class LaptopData(BaseModel):
|
|
processor: str
|
|
ram: int
|
|
os: str
|
|
ssd: int
|
|
display: float
|
|
|
|
def get_db():
|
|
db = SessionLocal()
|
|
try:
|
|
yield db
|
|
finally:
|
|
db.close()
|
|
|
|
|
|
@app.post("/laptops/", response_model=LaptopResponse)
|
|
def create_laptop(laptop: LaptopCreate, db: Session = Depends(get_db)):
|
|
db_laptop = Laptop(**laptop.dict())
|
|
db.add(db_laptop)
|
|
db.commit()
|
|
db.refresh(db_laptop)
|
|
return db_laptop
|
|
|
|
|
|
@app.get("/laptops/", response_model=List[LaptopResponse])
|
|
def read_laptops(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)):
|
|
laptops = db.query(Laptop).offset(skip).limit(limit).all()
|
|
return laptops
|
|
|
|
|
|
@app.get("/laptops/{laptop_id}", response_model=LaptopResponse)
|
|
def read_laptop(laptop_id: int, db: Session = Depends(get_db)):
|
|
laptop = db.query(Laptop).filter(Laptop.id == laptop_id).first()
|
|
if laptop is None:
|
|
raise HTTPException(status_code=404, detail="Laptop not found")
|
|
return laptop
|
|
|
|
|
|
# Эндпоинт для предсказания цены
|
|
@app.post("/predict_price/")
|
|
def predict_price(data: LaptopData):
|
|
input_data = data.dict()
|
|
|
|
# Преобразование данных в DataFrame
|
|
input_df = pd.DataFrame([input_data])
|
|
|
|
# Применение One-Hot Encoding к категориальным признакам
|
|
input_df = pd.get_dummies(input_df, columns=['processor', 'os'], drop_first=True)
|
|
|
|
# Добавление отсутствующих признаков, если они есть
|
|
for col in feature_columns:
|
|
if col not in input_df.columns and col != 'price':
|
|
input_df[col] = 0
|
|
|
|
# Упорядочивание колонок согласно обучающей выборке
|
|
input_df = input_df[feature_columns]
|
|
|
|
# Предсказание цены
|
|
predicted_price = model.predict(input_df)[0]
|
|
|
|
return {"predicted_price": round(predicted_price, 2)}
|