This commit is contained in:
maxnes3 2024-10-13 16:54:23 +04:00
parent 7507a56237
commit e83da8582d
2 changed files with 10 additions and 64 deletions

65
main.py
View File

@ -1,7 +1,7 @@
from fastapi import FastAPI, Depends, HTTPException from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from models import Laptop from models.models import Laptop
from schemas import LaptopCreate, LaptopResponse from schemas import LaptopCreate, LaptopResponse
from database import SessionLocal, engine, Base from database import SessionLocal, engine, Base
from pydantic import BaseModel from pydantic import BaseModel
@ -14,66 +14,3 @@ model = joblib.load('laptop_price_model.pkl')
feature_columns = joblib.load('feature_columns.pkl') feature_columns = joblib.load('feature_columns.pkl')
app = FastAPI() app = FastAPI()
# Определение Pydantic модели для входных данных
class LaptopData(BaseModel):
processor: str
ram: int
os: str
ssd: int
display: float
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@app.post("/laptops/", response_model=LaptopResponse)
def create_laptop(laptop: LaptopCreate, db: Session = Depends(get_db)):
db_laptop = Laptop(**laptop.dict())
db.add(db_laptop)
db.commit()
db.refresh(db_laptop)
return db_laptop
@app.get("/laptops/", response_model=List[LaptopResponse])
def read_laptops(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)):
laptops = db.query(Laptop).offset(skip).limit(limit).all()
return laptops
@app.get("/laptops/{laptop_id}", response_model=LaptopResponse)
def read_laptop(laptop_id: int, db: Session = Depends(get_db)):
laptop = db.query(Laptop).filter(Laptop.id == laptop_id).first()
if laptop is None:
raise HTTPException(status_code=404, detail="Laptop not found")
return laptop
# Эндпоинт для предсказания цены
@app.post("/predict_price/")
def predict_price(data: LaptopData):
input_data = data.dict()
# Преобразование данных в DataFrame
input_df = pd.DataFrame([input_data])
# Применение One-Hot Encoding к категориальным признакам
input_df = pd.get_dummies(input_df, columns=['processor', 'os'], drop_first=True)
# Добавление отсутствующих признаков, если они есть
for col in feature_columns:
if col not in input_df.columns and col != 'price':
input_df[col] = 0
# Упорядочивание колонок согласно обучающей выборке
input_df = input_df[feature_columns]
# Предсказание цены
predicted_price = model.predict(input_df)[0]
return {"predicted_price": round(predicted_price, 2)}

View File

@ -0,0 +1,9 @@
fastapi
uvicorn
sqlalchemy
psycopg2-binary
pydantic
joblib
pandas
numpy
python-dotenv