This commit is contained in:
Артём Алейкин 2024-10-13 16:53:07 +04:00
parent 847bb0694e
commit 6cbcb76867
7 changed files with 81 additions and 82 deletions

View File

@ -1,4 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (venv)" project-jdk-type="Python SDK" />
<component name="PyCharmProfessionalAdvertiser">
<option name="shown" value="true" />
</component>
</project>

View File

@ -0,0 +1,20 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from schemas.schemas import LaptopCreate, LaptopResponse, PredictPriceResponse
from services.service import LaptopService
import os
router = APIRouter()
# Service initialization: the artifact paths are read from the MODEL_PATH and
# FEATURE_COLUMNS_PATH environment variables, falling back to the relative
# file names below. The service is constructed once, when this module is imported.
MODEL_PATH = os.getenv("MODEL_PATH", "laptop_price_model.pkl")
FEATURE_COLUMNS_PATH = os.getenv("FEATURE_COLUMNS_PATH", "feature_columns.pkl")
laptop_service = LaptopService(model_path=MODEL_PATH, feature_columns_path=FEATURE_COLUMNS_PATH)
# Price prediction endpoint: forwards the request payload (as a plain dict) to
# the shared LaptopService and returns its PredictPriceResponse.
# Documented with comments rather than a docstring so the generated OpenAPI
# description of the route is left unchanged.
@router.post("/predict_price/", response_model=PredictPriceResponse)
def predict_price(data: LaptopCreate):
    try:
        return laptop_service.predict_price(data.dict())
    except Exception as e:
        # Every failure is still reported to the client as HTTP 400 with the
        # error text; chaining with `from e` keeps the original traceback
        # available in server logs instead of discarding it.
        raise HTTPException(status_code=400, detail=str(e)) from e

80
main.py
View File

@ -1,79 +1,7 @@
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from models import Laptop
from schemas import LaptopCreate, LaptopResponse
from database import SessionLocal, engine, Base
from pydantic import BaseModel
import joblib
import pandas as pd
import numpy as np
# Load the model and the list of feature columns (runs once, at import time)
model = joblib.load('laptop_price_model.pkl')
feature_columns = joblib.load('feature_columns.pkl')
from fastapi import FastAPI
from controllers import controller
app = FastAPI()
# Pydantic model for the input data of the /predict_price/ endpoint.
# NOTE(review): units of ram / ssd / display are not stated in this file —
# confirm against the training data before relying on them.
class LaptopData(BaseModel):
    processor: str
    ram: int
    os: str
    ssd: int
    display: float
def get_db():
    """Yield a database session and close it when the consumer is finished.

    Written as a generator so it can be plugged into ``Depends``: the code
    after ``yield`` runs in ``finally``, so the session is closed even if the
    consumer raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Create a laptop row from the request payload and return the stored record.
@app.post("/laptops/", response_model=LaptopResponse)
def create_laptop(laptop: LaptopCreate, db: Session = Depends(get_db)):
    fields = laptop.dict()
    record = Laptop(**fields)
    db.add(record)
    db.commit()
    # Reload the instance from the database after the commit so the returned
    # object reflects what was actually stored.
    db.refresh(record)
    return record
# List laptops, skipping `skip` rows and returning at most `limit` rows.
@app.get("/laptops/", response_model=List[LaptopResponse])
def read_laptops(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)):
    page_query = db.query(Laptop).offset(skip).limit(limit)
    return page_query.all()
# Fetch a single laptop by primary key; responds with 404 when no row matches.
@app.get("/laptops/{laptop_id}", response_model=LaptopResponse)
def read_laptop(laptop_id: int, db: Session = Depends(get_db)):
    match = db.query(Laptop).filter(Laptop.id == laptop_id).first()
    if match is not None:
        return match
    raise HTTPException(status_code=404, detail="Laptop not found")
# Price prediction endpoint.
# Builds a one-row feature frame from the request, aligns it with the columns
# the model was trained on, and returns the predicted price rounded to 2 places.
@app.post("/predict_price/")
def predict_price(data: LaptopData):
    input_data = data.dict()
    # Convert the payload into a single-row DataFrame
    input_df = pd.DataFrame([input_data])
    # One-hot encode the categorical features. drop_first must NOT be used
    # here: a single row has exactly one level per column, so dropping the
    # first level would remove every dummy column and the processor/os values
    # would never reach the model. The baseline category dropped at training
    # time is removed below, because it is not listed in feature_columns.
    input_df = pd.get_dummies(input_df, columns=['processor', 'os'])
    # Add the training columns that this request did not produce, filled with 0
    missing = [
        col for col in feature_columns
        if col not in input_df.columns and col != 'price'
    ]
    input_df = input_df.reindex(columns=[*input_df.columns, *missing], fill_value=0)
    # Keep only the training columns, in the training order
    input_df = input_df[feature_columns]
    # Predict the price; float() turns the numpy scalar into a plain Python float
    predicted_price = float(model.predict(input_df)[0])
    return {"predicted_price": round(predicted_price, 2)}
# Register the controller's routes on the application
app.include_router(controller.router)

View File

@ -5,9 +5,8 @@ class Laptop(Base):
__tablename__ = "laptops"
id = Column(Integer, primary_key=True, index=True)
title = Column(String, index=True)
price = Column(Float)
processor = Column(String)
processor = Column(String, index=True)
ram = Column(Integer)
os = Column(String, index=True)
ssd = Column(Integer)
display = Column(Float)

View File

@ -0,0 +1,9 @@
fastapi
uvicorn
sqlalchemy
psycopg2-binary
pydantic
joblib
pandas
numpy
python-dotenv

View File

@ -7,7 +7,7 @@ import joblib
import re
# Step 1: Load the data
df = pd.read_csv('laptops.csv')
df = pd.read_csv('../laptops.csv')
# Step 2: Check and clean up the column names
print("Имена столбцов до очистки:")
@ -134,10 +134,10 @@ for name, mdl in models.items():
print(f"{name} - MAE: {mae}, RMSE: {rmse}, R²: {r2}")
# Step 12: Save the model
joblib.dump(model, 'laptop_price_model.pkl')
joblib.dump(model, '../laptop_price_model.pkl')
print("\nМодель сохранена как 'laptop_price_model.pkl'.")
# Additionally: save the column names produced by One-Hot Encoding so the backend can reuse them
feature_columns = X.columns.tolist()
joblib.dump(feature_columns, 'feature_columns.pkl')
joblib.dump(feature_columns, '../feature_columns.pkl')
print("Сохранены названия признаков в 'feature_columns.pkl'.")

View File

@ -0,0 +1,40 @@
import pandas as pd
import joblib
from typing import List, Dict
from schemas.schemas import LaptopCreate, LaptopResponse, PredictPriceResponse
class LaptopService:
    """Price-prediction service backed by a serialized model and its feature list.

    Both artifacts are read with ``joblib.load`` once, when the service is
    constructed, and reused for every prediction.
    """

    def __init__(self, model_path: str, feature_columns_path: str):
        """Load the trained model and the ordered list of training columns.

        Args:
            model_path: Path of the joblib file containing the trained model.
            feature_columns_path: Path of the joblib file containing the
                column names the model was trained on.

        Raises:
            Exception: If either file is missing or cannot be loaded. The
                original error is chained as ``__cause__``.
        """
        self.model = self._load_artifact(model_path, "model")
        self.feature_columns = self._load_artifact(
            feature_columns_path, "feature columns"
        )

    @staticmethod
    def _load_artifact(path: str, label: str):
        """Load one joblib artifact, wrapping failures in a descriptive Exception."""
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code — only point these paths at trusted artifacts.
        try:
            return joblib.load(path)
        except FileNotFoundError as e:
            raise Exception(f"{label.capitalize()} file not found at {path}") from e
        except Exception as e:
            raise Exception(f"Error loading {label}: {str(e)}") from e

    def predict_price(self, data: Dict[str, object]) -> PredictPriceResponse:
        """Predict the price for one laptop described by ``data``.

        Args:
            data: Raw feature values for a single laptop. Must contain the
                ``processor`` and ``os`` keys, which are one-hot encoded.

        Returns:
            A ``PredictPriceResponse`` with the prediction rounded to 2 places.
        """
        # Convert the payload into a single-row DataFrame
        frame = pd.DataFrame([data])
        # One-hot encode the categorical features. drop_first must NOT be used
        # here: a single row has exactly one level per column, so dropping the
        # first level would remove every dummy column and the processor/os
        # values would never reach the model. The baseline category dropped at
        # training time is removed below, because it is not a training column.
        encoded = pd.get_dummies(frame, columns=['processor', 'os'])
        # Add the training columns that this request did not produce, filled with 0
        missing = [
            col for col in self.feature_columns
            if col not in encoded.columns and col != 'price'
        ]
        encoded = encoded.reindex(columns=[*encoded.columns, *missing], fill_value=0)
        # Keep only the training columns, in the training order
        encoded = encoded[self.feature_columns]
        # Predict; float() turns the numpy scalar into a plain Python float
        predicted_price = float(self.model.predict(encoded)[0])
        return PredictPriceResponse(predicted_price=round(predicted_price, 2))