price-builder-backend/services/service.py

55 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import joblib
from typing import List, Dict
from schemas.schemas import LaptopCreate, LaptopResponse, PredictPriceResponse
class LaptopService:
def __init__(self, model_path: str, feature_columns_path: str, poly_path: str, scaler_path: str):
try:
self.model = joblib.load(model_path)
except FileNotFoundError:
raise Exception(f"Model file not found at {model_path}")
except Exception as e:
raise Exception(f"Error loading model: {str(e)}")
try:
self.feature_columns = joblib.load(feature_columns_path)
except FileNotFoundError:
raise Exception(f"Feature columns file not found at {feature_columns_path}")
except Exception as e:
raise Exception(f"Error loading feature columns: {str(e)}")
try:
self.poly_transformer = joblib.load(poly_path)
self.scaler = joblib.load(scaler_path)
except FileNotFoundError:
raise Exception("Polynomial transformer or scaler file not found.")
except Exception as e:
raise Exception(f"Error loading polynomial transformer or scaler: {str(e)}")
def predict_price(self, data: Dict[str, any]) -> PredictPriceResponse:
# Преобразование данных в DataFrame
input_df = pd.DataFrame([data])
# Применение One-Hot Encoding к категориальным признакам
input_df = pd.get_dummies(input_df, columns=['processor', 'os'], drop_first=True)
# Добавление отсутствующих признаков
for col in self.feature_columns:
if col not in input_df.columns and col != 'price':
input_df[col] = 0
# Упорядочивание колонок
input_df = input_df[self.feature_columns]
# Преобразование с использованием PolynomialFeatures
input_poly = self.poly_transformer.transform(input_df)
# Масштабирование данных
input_scaled = self.scaler.transform(input_poly)
# Предсказание цены
predicted_price = self.model.predict(input_scaled)[0]
return PredictPriceResponse(predicted_price=round(predicted_price, 2))