com
This commit is contained in:
parent
7507a56237
commit
e83da8582d
65
main.py
65
main.py
@ -1,7 +1,7 @@
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from models import Laptop
|
||||
from models.models import Laptop
|
||||
from schemas import LaptopCreate, LaptopResponse
|
||||
from database import SessionLocal, engine, Base
|
||||
from pydantic import BaseModel
|
||||
@ -14,66 +14,3 @@ model = joblib.load('laptop_price_model.pkl')
|
||||
feature_columns = joblib.load('feature_columns.pkl')
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Определение Pydantic модели для входных данных
|
||||
class LaptopData(BaseModel):
|
||||
processor: str
|
||||
ram: int
|
||||
os: str
|
||||
ssd: int
|
||||
display: float
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@app.post("/laptops/", response_model=LaptopResponse)
|
||||
def create_laptop(laptop: LaptopCreate, db: Session = Depends(get_db)):
|
||||
db_laptop = Laptop(**laptop.dict())
|
||||
db.add(db_laptop)
|
||||
db.commit()
|
||||
db.refresh(db_laptop)
|
||||
return db_laptop
|
||||
|
||||
|
||||
@app.get("/laptops/", response_model=List[LaptopResponse])
|
||||
def read_laptops(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)):
|
||||
laptops = db.query(Laptop).offset(skip).limit(limit).all()
|
||||
return laptops
|
||||
|
||||
|
||||
@app.get("/laptops/{laptop_id}", response_model=LaptopResponse)
|
||||
def read_laptop(laptop_id: int, db: Session = Depends(get_db)):
|
||||
laptop = db.query(Laptop).filter(Laptop.id == laptop_id).first()
|
||||
if laptop is None:
|
||||
raise HTTPException(status_code=404, detail="Laptop not found")
|
||||
return laptop
|
||||
|
||||
|
||||
# Эндпоинт для предсказания цены
|
||||
@app.post("/predict_price/")
|
||||
def predict_price(data: LaptopData):
|
||||
input_data = data.dict()
|
||||
|
||||
# Преобразование данных в DataFrame
|
||||
input_df = pd.DataFrame([input_data])
|
||||
|
||||
# Применение One-Hot Encoding к категориальным признакам
|
||||
input_df = pd.get_dummies(input_df, columns=['processor', 'os'], drop_first=True)
|
||||
|
||||
# Добавление отсутствующих признаков, если они есть
|
||||
for col in feature_columns:
|
||||
if col not in input_df.columns and col != 'price':
|
||||
input_df[col] = 0
|
||||
|
||||
# Упорядочивание колонок согласно обучающей выборке
|
||||
input_df = input_df[feature_columns]
|
||||
|
||||
# Предсказание цены
|
||||
predicted_price = model.predict(input_df)[0]
|
||||
|
||||
return {"predicted_price": round(predicted_price, 2)}
|
||||
|
@ -0,0 +1,9 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
sqlalchemy
|
||||
psycopg2-binary
|
||||
pydantic
|
||||
joblib
|
||||
pandas
|
||||
numpy
|
||||
python-dotenv
|
Loading…
Reference in New Issue
Block a user