price-builder-backend/services/ml/modelBuilder.py

114 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
import matplotlib.pyplot as plt
import joblib
import numpy as np
# Шаг 1: Загрузка данных
df = pd.read_csv('../../datasets/synthetic_laptops.csv')
# Шаг 2: Проверка и очистка имен столбцов
df.columns = df.columns.str.strip().str.lower()
# Шаг 3: Проверка наличия необходимых столбцов
required_columns = [
'brand', 'processor', 'ram', 'os', 'ssd', 'display',
'gpu', 'weight', 'battery_size', 'release_year', 'display_type', 'price'
]
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
raise Exception(f"Отсутствуют столбцы: {missing_columns}")
# Шаг 4: Удаление строк с пропущенными значениями
df = df.dropna(subset=required_columns)
# Шаг 5: Очистка и преобразование колонок
def clean_numeric_column(column, remove_chars=['', ',', ' ']):
if column.dtype == object:
for char in remove_chars:
column = column.str.replace(char, '', regex=False)
return pd.to_numeric(column, errors='coerce')
else:
return column
numerical_columns = ['ram', 'ssd', 'display', 'weight', 'battery_size', 'release_year']
for col in numerical_columns:
df[col] = clean_numeric_column(df[col])
df = df.dropna(subset=['price'] + numerical_columns)
# Шаг 6: Преобразование категориальных переменных с помощью One-Hot Encoding
categorical_features = ['brand', 'processor', 'os', 'gpu', 'display_type']
df = pd.get_dummies(df, columns=categorical_features, drop_first=True)
# Шаг 7: Разделение данных на X и y
X = df.drop('price', axis=1)
y = df['price']
# Шаг 8: Создание полиномиальных и интерактивных признаков степени 2
poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
X_poly = poly.fit_transform(X)
# Шаг 9: Масштабирование признаков
scaler = StandardScaler()
X_poly_scaled = scaler.fit_transform(X_poly)
# Шаг 10: Разделение на обучающую и тестовую выборки
X_train, X_test, y_train, y_test = train_test_split(X_poly_scaled, y, test_size=0.5, random_state=42)
# Шаг 11: Настройка гиперпараметров с использованием GridSearchCV
param_grid = {
'n_estimators': [100, 200],
'max_depth': [10, 20],
'max_features': ['sqrt', 'log2', 0.5],
'min_samples_split': [5, 10],
'min_samples_leaf': [2, 4]
}
grid_search = GridSearchCV(RandomForestRegressor(random_state=42), param_grid, cv=3, scoring='neg_mean_absolute_error')
grid_search.fit(X_train, y_train)
# Лучшая модель
best_model = grid_search.best_estimator_
# Шаг 12: Предсказания и оценка
y_pred = best_model.predict(X_test)
mae = mean_absolute_error(y_test, y_pred)
rmse = mean_squared_error(y_test, y_pred, squared=False)
r2 = r2_score(y_test, y_pred)
print(f"Лучшие параметры: {grid_search.best_params_}")
print(f"Random Forest - MAE: {mae}, RMSE: {rmse}, R²: {r2}")
# Шаг 13: Сохранение модели
feature_columns = X.columns.tolist()
joblib.dump(feature_columns, 'feature_columns.pkl')
joblib.dump(best_model, 'laptop_price_model.pkl')
joblib.dump(poly, 'poly_transformer.pkl')
joblib.dump(scaler, 'scaler.pkl')
print("Модель, трансформер и скейлер сохранены.")
# Шаг 14: Важность признаков
# Количество признаков, которые нужно отобразить
top_n = 15
# Важность признаков
importances = best_model.feature_importances_
indices = np.argsort(importances)[::-1]
# Отображаем только топ-N признаков
top_indices = indices[:top_n]
top_importances = importances[top_indices]
top_features = np.array(poly.get_feature_names_out())[top_indices]
# Построение графика
plt.figure(figsize=(12, 8))
plt.title(f"Топ-{top_n} признаков по важности (Random Forest)")
plt.bar(range(top_n), top_importances, align='center')
plt.xticks(range(top_n), top_features, rotation=45, ha='right')
plt.xlabel("Признаки")
plt.ylabel("Важность")
plt.tight_layout()
plt.show()