com
This commit is contained in:
parent
7507a56237
commit
e83da8582d
65
main.py
65
main.py
@ -1,7 +1,7 @@
|
|||||||
from fastapi import FastAPI, Depends, HTTPException
|
from fastapi import FastAPI, Depends, HTTPException
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List
|
from typing import List
|
||||||
from models import Laptop
|
from models.models import Laptop
|
||||||
from schemas import LaptopCreate, LaptopResponse
|
from schemas import LaptopCreate, LaptopResponse
|
||||||
from database import SessionLocal, engine, Base
|
from database import SessionLocal, engine, Base
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -14,66 +14,3 @@ model = joblib.load('laptop_price_model.pkl')
|
|||||||
feature_columns = joblib.load('feature_columns.pkl')
|
feature_columns = joblib.load('feature_columns.pkl')
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# Определение Pydantic модели для входных данных
|
|
||||||
class LaptopData(BaseModel):
|
|
||||||
processor: str
|
|
||||||
ram: int
|
|
||||||
os: str
|
|
||||||
ssd: int
|
|
||||||
display: float
|
|
||||||
|
|
||||||
def get_db():
|
|
||||||
db = SessionLocal()
|
|
||||||
try:
|
|
||||||
yield db
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/laptops/", response_model=LaptopResponse)
|
|
||||||
def create_laptop(laptop: LaptopCreate, db: Session = Depends(get_db)):
|
|
||||||
db_laptop = Laptop(**laptop.dict())
|
|
||||||
db.add(db_laptop)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(db_laptop)
|
|
||||||
return db_laptop
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/laptops/", response_model=List[LaptopResponse])
|
|
||||||
def read_laptops(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)):
|
|
||||||
laptops = db.query(Laptop).offset(skip).limit(limit).all()
|
|
||||||
return laptops
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/laptops/{laptop_id}", response_model=LaptopResponse)
|
|
||||||
def read_laptop(laptop_id: int, db: Session = Depends(get_db)):
|
|
||||||
laptop = db.query(Laptop).filter(Laptop.id == laptop_id).first()
|
|
||||||
if laptop is None:
|
|
||||||
raise HTTPException(status_code=404, detail="Laptop not found")
|
|
||||||
return laptop
|
|
||||||
|
|
||||||
|
|
||||||
# Эндпоинт для предсказания цены
|
|
||||||
@app.post("/predict_price/")
|
|
||||||
def predict_price(data: LaptopData):
|
|
||||||
input_data = data.dict()
|
|
||||||
|
|
||||||
# Преобразование данных в DataFrame
|
|
||||||
input_df = pd.DataFrame([input_data])
|
|
||||||
|
|
||||||
# Применение One-Hot Encoding к категориальным признакам
|
|
||||||
input_df = pd.get_dummies(input_df, columns=['processor', 'os'], drop_first=True)
|
|
||||||
|
|
||||||
# Добавление отсутствующих признаков, если они есть
|
|
||||||
for col in feature_columns:
|
|
||||||
if col not in input_df.columns and col != 'price':
|
|
||||||
input_df[col] = 0
|
|
||||||
|
|
||||||
# Упорядочивание колонок согласно обучающей выборке
|
|
||||||
input_df = input_df[feature_columns]
|
|
||||||
|
|
||||||
# Предсказание цены
|
|
||||||
predicted_price = model.predict(input_df)[0]
|
|
||||||
|
|
||||||
return {"predicted_price": round(predicted_price, 2)}
|
|
||||||
|
@ -0,0 +1,9 @@
|
|||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
|
sqlalchemy
|
||||||
|
psycopg2-binary
|
||||||
|
pydantic
|
||||||
|
joblib
|
||||||
|
pandas
|
||||||
|
numpy
|
||||||
|
python-dotenv
|
Loading…
Reference in New Issue
Block a user