873 lines
219 KiB
Plaintext
873 lines
219 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 29,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.pipeline import make_pipeline\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.linear_model import LinearRegression\n",
|
|||
|
"from sklearn.preprocessing import PolynomialFeatures\n",
|
|||
|
"from sklearn.metrics import mean_squared_error\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"from scipy import stats\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.tree import DecisionTreeRegressor\n",
|
|||
|
"from sklearn.ensemble import GradientBoostingRegressor\n",
|
|||
|
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
|
|||
|
"\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//balanced_neo.csv\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## **1-я бизнес-цель (регрессия)**: \n",
|
|||
|
"\n",
|
|||
|
"Предсказание скорости космического объекта для принятия решения о том, насколько опасным он может быть и стоит ли вести за ним наблюдения"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Целевой признак: скорость космического объекта relative_velocity\n",
|
|||
|
"\n",
|
|||
|
"Вход: минимальный радиус est_diameter_min, максимальный радиус est_diameter_max, яркость объекта absolute_magnitude, расстояние от Земли miss_distance"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Достижимый уровень качества: предсказания должны иметь погрешность в среднем не более 10000 км/с. Для проверки будет использоваться метрика MAE (средняя абсолютная ошибка)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 68,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.discriminant_analysis import StandardScaler\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.model_selection import GridSearchCV, train_test_split\n",
|
|||
|
"from sklearn.metrics import roc_auc_score, confusion_matrix, accuracy_score\n",
|
|||
|
"from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier\n",
|
|||
|
"import seaborn as sns\n",
|
|||
|
"from sklearn.model_selection import cross_val_predict\n",
|
|||
|
"from sklearn.metrics import mean_squared_error\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"import sklearn.preprocessing as preproc\n",
|
|||
|
"from sklearn.linear_model import LinearRegression, Ridge\n",
|
|||
|
"from sklearn.metrics import mean_absolute_error\n",
|
|||
|
"from mlxtend.evaluate import bias_variance_decomp\n",
|
|||
|
"from sklearn.neural_network import MLPRegressor\n",
|
|||
|
"\n",
|
|||
|
"# Загрузка данных\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//balanced_neo.csv\")\n",
|
|||
|
"data = df[['est_diameter_min', 'est_diameter_max', 'absolute_magnitude', 'miss_distance', 'relative_velocity']]\n",
|
|||
|
"\n",
|
|||
|
"X = data.drop('relative_velocity', axis=1)\n",
|
|||
|
"y = data['relative_velocity']\n",
|
|||
|
"\n",
|
|||
|
"# Разделение данных на обучающую и тестовую выборки\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование числовых данных\n",
|
|||
|
"#заполнение пустых значений медианой\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer)\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"#Категориальных данных нет, поэтому преобразовывать их не надо\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# Общая предобработка (только числовые данные)\n",
|
|||
|
"preprocessing = ColumnTransformer(\n",
|
|||
|
" [\n",
|
|||
|
" (\"nums\", preprocessing_num, X.columns)\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Линейная регрессия"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 44,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Лучшие гиперпараметры: {'preprocessing': MinMaxScaler()}\n",
|
|||
|
"Cредняя абсолютная ошибка (MAE) = 19241.554618019443\n",
|
|||
|
"Смещение: 616083845.5088656\n",
|
|||
|
"Дисперсия: 438598.9204950822\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"pipeline_lin_reg = Pipeline([\n",
|
|||
|
" ('preprocessing', preprocessing),\n",
|
|||
|
" ('model', LinearRegression())]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Определение сетки гиперпараметров (возможных знач-ий гиперпараметров) для перебора\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" #как будут масштабироваться признаки\n",
|
|||
|
" 'preprocessing': [StandardScaler(), preproc.MinMaxScaler(), preproc.MaxAbsScaler(), None]\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Создание объекта GridSearchCV для поиска лучших гиперпараметров по сетке с максимальным знач-ием \n",
|
|||
|
"# отрицательного корня из среднеквадратичной ошибки (отриц., чтобы искался не минимум, а максимум)\n",
|
|||
|
"grid_search = GridSearchCV(pipeline_lin_reg, param_grid, cv=5, scoring='neg_root_mean_squared_error', n_jobs=-1)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели с перебором гиперпараметров\n",
|
|||
|
"grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Лучшие гиперпараметры: \", grid_search.best_params_)\n",
|
|||
|
"\n",
|
|||
|
"# Лучшая модель лин. регрессии\n",
|
|||
|
"best_model = grid_search.best_estimator_\n",
|
|||
|
"\n",
|
|||
|
"y_pred = best_model.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
"print(f'Cредняя абсолютная ошибка (MAE) = {mean_absolute_error(y_test, y_pred)}')\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"mse, bias, variance = bias_variance_decomp(best_model, X_train.values, y_train.values, X_test.values, y_test.values, loss='mse', num_rounds=200, random_seed=123)\n",
|
|||
|
"print(\"Смещение: \", bias)\n",
|
|||
|
"print(\"Дисперсия: \", variance)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Гребневая регрессия"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 46,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Лучшие гиперпараметры: {'model__alpha': 10.0, 'preprocessing': MinMaxScaler()}\n",
|
|||
|
"Cредняя абсолютная ошибка (MAE) = 19239.098934204343\n",
|
|||
|
"Смещение: 615921619.3705255\n",
|
|||
|
"Дисперсия: 326886.495836047\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"pipeline_ridge = Pipeline([\n",
|
|||
|
" ('preprocessing', preprocessing),\n",
|
|||
|
" ('model', Ridge())]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Определение сетки гиперпараметров (возможных знач-ий гиперпараметров) для перебора\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" #как будут масштабироваться признаки\n",
|
|||
|
" 'preprocessing': [StandardScaler(), preproc.MinMaxScaler(), preproc.MaxAbsScaler(), None],\n",
|
|||
|
" #сила регуляризации\n",
|
|||
|
" 'model__alpha': [0, 0.5, 1.0, 1.5, 2.0, 5.0, 10.0] \n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Создание объекта GridSearchCV для поиска лучших гиперпараметров по сетке с максимальным знач-ием \n",
|
|||
|
"# отрицательного корня из среднеквадратичной ошибки (отриц., чтобы искался не минимум, а максимум)\n",
|
|||
|
"grid_search = GridSearchCV(pipeline_ridge, param_grid, cv=5, scoring='neg_root_mean_squared_error', n_jobs=-1, verbose=0)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели с перебором гиперпараметров\n",
|
|||
|
"grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Лучшие гиперпараметры: \", grid_search.best_params_)\n",
|
|||
|
"\n",
|
|||
|
"# Лучшая модель регрессии\n",
|
|||
|
"best_model = grid_search.best_estimator_\n",
|
|||
|
"\n",
|
|||
|
"y_pred = best_model.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
"print(f'Cредняя абсолютная ошибка (MAE) = {mean_absolute_error(y_test, y_pred)}')\n",
|
|||
|
"\n",
|
|||
|
"mse, bias, variance = bias_variance_decomp(best_model, X_train.values, y_train.values, X_test.values, y_test.values, loss='mse', num_rounds=200, random_seed=123)\n",
|
|||
|
"print(\"Смещение: \", bias)\n",
|
|||
|
"print(\"Дисперсия: \", variance)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Гребневая регрессия показала почти такие же результаты, что и линейная регрессия"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"MLP"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Лучшие гиперпараметры: {'model__activation': 'relu', 'model__alpha': 0.0001, 'preprocessing': StandardScaler()}\n",
|
|||
|
"Cредняя абсолютная ошибка (MAE) = 19363.27371661712\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:697: UserWarning: Training interrupted by user.\n",
|
|||
|
" warnings.warn(\"Training interrupted by user.\")\n",
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:697: UserWarning: Training interrupted by user.\n",
|
|||
|
" warnings.warn(\"Training interrupted by user.\")\n",
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\AI labs\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Создание пайплайна для модели MLP\n",
|
|||
|
"pipeline_mlp = Pipeline([\n",
|
|||
|
" ('preprocessing', StandardScaler()), # По умолчанию используем StandardScaler\n",
|
|||
|
" ('model', MLPRegressor(random_state=42, max_iter=500)) # MLP модель с фиксированным random_state и количеством итераций\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Определение сетки гиперпараметров для перебора\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" 'preprocessing': [StandardScaler(), preproc.MinMaxScaler()], # Разные способы масштабирования признаков\n",
|
|||
|
" 'model__hidden_layer_sizes': [(50,), (100,)], # Разные конфигурации слоев\n",
|
|||
|
" #'model__activation': ['relu', 'tanh'],\n",
|
|||
|
" 'model__alpha': [0.0001, 0.001], # Разные значения регуляризации\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Создание объекта GridSearchCV для поиска лучших гиперпараметров\n",
|
|||
|
"grid_search = GridSearchCV(pipeline_mlp, param_grid, cv=5, scoring='neg_root_mean_squared_error', n_jobs=-1, verbose=0)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели с перебором гиперпараметров\n",
|
|||
|
"grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Лучшие гиперпараметры: \", grid_search.best_params_)\n",
|
|||
|
"\n",
|
|||
|
"# Лучшая модель MLP\n",
|
|||
|
"best_model = grid_search.best_estimator_\n",
|
|||
|
"\n",
|
|||
|
"y_pred = best_model.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
"print(f'Cредняя абсолютная ошибка (MAE) = {mean_absolute_error(y_test, y_pred)}')\n",
|
|||
|
"\n",
|
|||
|
"# Расчет смещения и дисперсии\n",
|
|||
|
"mse, bias, variance = bias_variance_decomp(best_model, X_train.values, y_train.values, X_test.values, y_test.values, loss='mse', num_rounds=200, random_seed=123)\n",
|
|||
|
"print(\"Смещение: \", bias)\n",
|
|||
|
"print(\"Дисперсия: \", variance)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Метод градиентного бустинга (набор деревьев решений)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Лучшие гиперпараметры: {'model__learning_rate': 0.1, 'model__max_depth': 3, 'model__n_estimators': 100, 'preprocessing': None}\n",
|
|||
|
"Cредняя абсолютная ошибка (MAE) = 18905.987766249527\n",
|
|||
|
"Смещение: -3.2312558004292335\n",
|
|||
|
"Дисперсия: 162393666.8715257\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Конвейер\n",
|
|||
|
"pipeline_grad = Pipeline([\n",
|
|||
|
" ('preprocessing', preprocessing),\n",
|
|||
|
" ('model', GradientBoostingRegressor())\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Определение сетки гиперпараметров\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" 'preprocessing': [StandardScaler(), preproc.MinMaxScaler(), preproc.MaxAbsScaler(), None],\n",
|
|||
|
" 'model__n_estimators': [100, 200, 300],\n",
|
|||
|
" #Скорость обучения\n",
|
|||
|
" 'model__learning_rate': [0.1, 0.2],\n",
|
|||
|
" #Максимальная глубина дерева\n",
|
|||
|
" 'model__max_depth': [3, 5, 7]\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Создание объекта GridSearchCV\n",
|
|||
|
"grid_search = GridSearchCV(pipeline_grad, param_grid, cv=2, scoring='neg_root_mean_squared_error', n_jobs=-1)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели с перебором гиперпараметров\n",
|
|||
|
"grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Лучшие гиперпараметры: \", grid_search.best_params_)\n",
|
|||
|
"\n",
|
|||
|
"# Лучшая модель случайного леса\n",
|
|||
|
"best_model = grid_search.best_estimator_\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"y_pred = best_model.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"print(f'Cредняя абсолютная ошибка (MAE) = {mean_absolute_error(y_test, y_pred)}')\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# Получение предсказаний на кросс-валидации\n",
|
|||
|
"y_cv_pred = cross_val_predict(best_model, X_train, y_train, cv=3)\n",
|
|||
|
"\n",
|
|||
|
"# Оценка смещения\n",
|
|||
|
"bias = np.mean(y_cv_pred - y_train)\n",
|
|||
|
"\n",
|
|||
|
"# Оценка дисперсии\n",
|
|||
|
"variance = np.var(y_cv_pred)\n",
|
|||
|
"\n",
|
|||
|
"print(f\"Смещение: {bias}\")\n",
|
|||
|
"print(f\"Дисперсия: {variance}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"**Вывод**: \n",
|
|||
|
"\n",
|
|||
|
"Все 3 модели регрессии не показали необходимого уровня \"погрешности\". Это означает, что необходимо использовать более сложные модели или что по доступным данным нельзя достичь необходимой погрешности.\n",
|
|||
|
"\n",
|
|||
|
"Из всех моделей градиентный бустинг показал самую низкую \"погрешность\""
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## **2-я бизнес-цель (классификация):** \n",
|
|||
|
"\n",
|
|||
|
"Определение опасности космиеского объекта для увеличения безопасности Земли"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Целевой признак: опасность объекта hazardous\n",
|
|||
|
"\n",
|
|||
|
"Вход: минимальный радиус est_diameter_min, максимальный радиус est_diameter_max, яркость объекта absolute_magnitude, скорость relative_velocity"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Достижимый уровень качества: необходимо, чтобы точность предсказания модели составляла не менее 90%. Для проверки этого будет использована метрика Accuracy"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 57,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.discriminant_analysis import StandardScaler\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder\n",
|
|||
|
"from sklearn.linear_model import LogisticRegression\n",
|
|||
|
"from sklearn.model_selection import GridSearchCV, train_test_split\n",
|
|||
|
"from sklearn.metrics import roc_auc_score, confusion_matrix, accuracy_score\n",
|
|||
|
"from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier\n",
|
|||
|
"import seaborn as sns\n",
|
|||
|
"from sklearn.model_selection import cross_val_predict\n",
|
|||
|
"from sklearn.metrics import mean_squared_error\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"# Загрузка данных\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//balanced_neo.csv\")\n",
|
|||
|
"data = df[['est_diameter_min', 'est_diameter_max', 'absolute_magnitude', 'relative_velocity', 'hazardous']]\n",
|
|||
|
"\n",
|
|||
|
"X = data.drop('hazardous', axis=1)\n",
|
|||
|
"y = data['hazardous']\n",
|
|||
|
"\n",
|
|||
|
"# Разделение данных на обучающую и тестовую выборки\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование числовых данных\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Общая предобработка (только числовые данные)\n",
|
|||
|
"preprocessing = ColumnTransformer(\n",
|
|||
|
" [\n",
|
|||
|
" (\"nums\", preprocessing_num, X.columns),\n",
|
|||
|
" ]\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Логистическая регрессия"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 66,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Лучшие гиперпараметры: {'classifier__C': 0.1, 'classifier__penalty': 'l1', 'classifier__solver': 'liblinear'}\n",
|
|||
|
"ROC у логистической регрессии = 0.8670867396912991\n",
|
|||
|
"Точность = 0.8591628959276018\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGwCAYAAABVdURTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA6bElEQVR4nO3df5zNdf7///uc4cwvZrCzZgZTgyKWCPEZKW+ZGv0Q21Yqy6TSVkytWRXCUMJWpJWyKYl3LWn7YdF4Z4pFdhWmX8RipDCjWTXDYIZznt8/fJ1mxozmjPPzdW7Xy+VcLs5rXq9zHuflx7l7vh7P5yvMGGMEAABgETZ/FwAAAOBJhBsAAGAphBsAAGAphBsAAGAphBsAAGAphBsAAGAphBsAAGAp9fxdgK85nU4dOHBADRs2VFhYmL/LAQAAtWCM0ZEjR9SsWTPZbOcemwm5cHPgwAElJyf7uwwAAFAH3333nVq0aHHOfUIu3DRs2FDS6ZMTGxvr52oAAEBtlJSUKDk52fU9fi4hF27OXIqKjY0l3AAAEGRq01JCQzEAALAUwg0AALAUwg0AALAUwg0AALAUwg0AALAUwg0AALAUwg0AALAUwg0AALAUwg0AALAUwg0AALAUv4abf/7zn+rfv7+aNWumsLAwvffee794zJo1a9SlSxdFRETooosu0oIFC7xeJwAACB5+DTelpaXq1KmT5syZU6v98/PzdcMNN6hPnz7Ky8vTH//4R917771atWqVlysFAADBwq83zrzuuut03XXX1Xr/uXPnqmXLlpoxY4YkqV27dlq/fr2ee+45paene6tMAADcZozR8ZMOf5fhN1H1w2t1k0tvCKq7gm/cuFFpaWmVtqWnp+uPf/xjjceUlZWprKzM9bykpMRb5QEBLdT/oQV8yRjp1rkbte1g6H7nbHsiXdF2/8SMoAo3BQUFSkhIqLQtISFBJSUlOn78uKKios46Ztq0aZo8ebKvSkQICaawwD+0AEJJUIWbuhg7dqyysrJcz0tKSpScnOzHihCsKoYZwgKA2mifFKul96fKT1dn/Cqqfrjf3juowk1iYqIKCwsrbSssLFRsbGy1ozaSFBERoYiICF+UBwupOipjlTATyv/QAv7gz76TUBZU4SY1NVUrV66stO3DDz9UamqqnyqC1RhjdKzcUesgE2xhgX9oAYQCv4abo0ePateuXa7n+fn5ysvLU5MmTXTBBRdo7Nix2r9/vxYuXChJuv/++/XCCy/o0Ucf1d13362PPvpIb731llasWOGvjwALcTqNbpy9/pyhpmqYISwAQODxa7j57LPP1KdPH9fzM70xGRkZWrBggQ4ePKh9+/a5ft6yZUutWLFCo0aN0vPPP68WLVrolVdeYRo4zpvTadR35lrlF5W6tlU3KkOYAYDAF2aMMf4uwpdKSkoUFxen4uJixcbG+rsc+NmZy1A3zl7vCjYt42O0PLOXou0EGQAIFO58fwdVzw3gKTX11rSMj1FuVm/ZbIQaAAhWhBuEnJp6a9onxWp5Zi+CDQAEOcINQooxZwebM701XIYCAGsg3CCkHCt3uIINvTUAYE2EG4QMY4xunbvR9Xx5Zi/FRPBXAACsxubvAgBfqThq0z4pVtF2/y0NDgDwHsINQkLVUZvT69dwKQoArIhwg5Bw/CSjNgAQKgg3CDmM2gCAtRFuEHLINQBgbYQbAABgKYQbAABgKYQbhITQuj0sAIQ2wg0sr+o0cACAtRFuYHlVF++Lqs80cACwMsINLO3MHcDPYBo4AFgf4QaW5XQa9Z25VvlFpZJYvA8AQgXhBpZkzOkRmzPB5swdwBm1AQDrI9zAkirebqFlfIxys3rLZiPYAEAoINzAkipO/V6e2YtgAwAhhHADy6k69ZsrUQAQWgg3sJyqdwBn6jcAhBbCDSyNqd8AEHoIN7A0cg0AhB7CDSyH+0gBQGgj3MBSqq5IDAAIPYQbWEbVhftoJgaA0ES4gWVUXbiPFYkBIDQRbmBJLNwHAKGLcAPLqNhIzIANAIQuwg0soeqqxACA0EW4gSWwKjEA4AzCDSyHVYkBILQRbmAJ9NsAAM4g3CDo0W8DAKiIcIOgR78NAKAiwg0shX4bAADhBkGPfhsAQEWEGwQ1+m0AAFURbhDU6LcBAFRFuIFl0G8DAJAIN7AQcg0AQCLcIMhVbCYGAEAi3CCIOZ1GN85e7+8yAAABhnCDoOR0GvWduVb5RaWSaCYGAPyMcIOgY8zpEZszwaZlfIyWZ/aimRgAIIlwgyB0rPzn6d8t42OUm9VbNhvBBgBwGuEGQaXqon3LM3sRbAAAlRBuEFSqLtoXbafPBgBQGeEGQYtF+wAA1SHcIGiRawAA1SHcIKiwaB8A4JcQbhA0uAM4AKA2CDcIGtwBHABQG4QbBCWaiQEANSHcICiRawAANSHcAAAAS/F7uJkzZ45SUlIUGRmpHj16aNOmTefcf9asWWrbtq2ioqKUnJysUaNG6cSJEz6qFgAABDq/hpslS5YoKytL2dnZ2rJlizp16qT09HQdOnSo2v3ffPNNjRkzRtnZ2dq+fbteffVVLVmyROPGjfNx5fAHpoEDAGrDr+Fm5syZGj58uIYNG6b27dtr7ty5io6O1vz586vd/5NPPtEVV1yhO++8UykpKbr22mt1xx13nHO0p6ysTCUlJZUeCD5MAwcA1Jbfwk15ebk2b96stLS0n4ux2ZSWlqaNG6v/EuvZs6c2b97sCjN79uzRypUrdf3119f4PtOmTVNcXJzrkZyc7NkPAp9gGjgAoLbq+euNi4qK5HA4lJCQUGl7QkKCvvnmm2qPufPOO1VUVKRevXrJGKNTp07p/vvvP+dlqbFjxyorK8v1vKSkhIAT5JgGDgA4F783FLtjzZo1mjp1ql588UVt2bJF77zzjlasWKEnn3yyxmMiIiIUGxtb6YHgU7HfhlwDADgXv43cxMfHKzw8XIWFhZW2FxYWKjExsdpjJkyYoCFDhujee++VJHXs2FGlpaW677779Pjjj8tmC6qshlqi3wYA4A6/pQG73a6uXbsqNzfXtc3pdCo3N1epqanVHnPs2LGzAkx4+OneC8NUGsui3wYA4A6/jdxIUlZWljIyMtStWzd1795ds2bNUmlpqYYNGyZJGjp0qJo3b65p06ZJkvr376+ZM2fqsssuU48ePbRr1y5NmDBB/fv3d4UcWBv9NgCAX+LXcDNo0CD98MMPmjhxogoKCtS5c2fl5OS4moz37dtXaaRm/PjxCgsL0/jx47V//379+te/Vv/+/fXUU0/56yPAB+i3AQC4I8yE2PWckpISxcXFqbi4mObiIGCM0Q1/We+6LLXtiXRF2/2ayQEAfuDO9zcduAho9NsAANxFuEHQoN8GAFAbhBsENPptAADuItwgYLG+DQCgLgg3CFjHyum3AQC4j3CDgOR0Gt04e73rOf02AIDaItwg4BhzOtjkF5VKOj1qE21n1AYAUDuEGwScitO/W8bHaHlmL0ZtAAC1RrhBQFue2Us2G8EGAFB7hBsEHKZ/AwDOB+EGAYXp3wCA80W4QUDhdgsAgPNFuEHAYvo3AKAuCDcIKPTbAADOF+EGAYN+GwCAJxBuEDC43QIAwBMINwgIVUdt6LcBANQV4QYBoeqoDbdbAADUFeEGfseoDQDAkwg38Luqa9swagMAOB+EGwQURm0AAOeLcIOAQq4BAJwvwg0AALAUwg38ruKqxAAAnC/CDfyKVYkBAJ5GuIFfcRdwAICnEW4QMJgpBQDwBMINAga5BgDgCYQbAABgKYQbAABgKYQbAABgKYQb+BVr3AAAPI1wA79hjRsAgDecV7g5ceKEp+pACDpWzho3AADPczvcOJ1OPfnkk2revLkaNGigPXv2SJImTJigV1991eMFwpqqjtqwxg0AwFPcDjdTpkzRggUL9PTTT8tut7u2d+jQQa+88opHi4N1VV2ZONrOqA0AwDPcDjcLFy7Uyy+/rMGDBys8/OcvpE6dOumbb77xaHEIDYzaAAA8ye1ws3//fl100UVnbXc6nTp58qRHikJoIdcAADzJ7XDTvn17rVu
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAApIAAAIjCAYAAACwHvu2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB2x0lEQVR4nO3deVxU5fv/8dcgsggCoiJiLrjjvpXhblLmvn7KMsWlzTRzXyrNNQpzSS3NTFHTzNxSK9M0JY3ctxRNUyMX1EREXJBlfn/4Y75OqMHIcBDez8/jPD7Mfe65z3UGHa+uc5/7mMxmsxkRERERkQxyMDoAEREREXk0KZEUEREREZsokRQRERERmyiRFBERERGbKJEUEREREZsokRQRERERmyiRFBERERGbKJEUEREREZsokRQRERERmyiRFJEHOn78OM888wyenp6YTCZWr16dqeOfPn0ak8lEWFhYpo77KGvSpAlNmjQxOgwRkf+kRFLkEfDnn3/y2muvUbp0aVxcXPDw8KB+/fp8/PHH3Lx5067HDg4O5tChQ0ycOJFFixZRp04dux4vK/Xo0QOTyYSHh8c9P8fjx49jMpkwmUx89NFHGR7/3LlzjBkzhv3792dCtCIi2Y+j0QGIyIN99913/O9//8PZ2Znu3btTpUoVbt++zbZt2xg6dCiHDx9mzpw5djn2zZs3iYiI4J133qFfv352OUbJkiW5efMmefPmtcv4/8XR0ZEbN26wdu1annvuOat9ixcvxsXFhVu3btk09rlz5xg7diylSpWiRo0a6X7fhg0bbDqeiEhWUyIpko2dOnWKLl26ULJkSTZv3kzRokUt+/r27cuJEyf47rvv7Hb8S5cuAeDl5WW3Y5hMJlxcXOw2/n9xdnamfv36fPXVV2kSySVLltCqVStWrFiRJbHcuHGDfPny4eTklCXHExF5WLq0LZKNhYaGEh8fzxdffGGVRKYqW7Ysb731luV1UlIS48ePp0yZMjg7O1OqVCnefvttEhISrN5XqlQpWrduzbZt23jiiSdwcXGhdOnSLFy40NJnzJgxlCxZEoChQ4diMpkoVaoUcOeScOrPdxszZgwmk8mqbePGjTRo0AAvLy/c3d2pUKECb7/9tmX//eZIbt68mYYNG+Lm5oaXlxft2rUjMjLynsc7ceIEPXr0wMvLC09PT3r27MmNGzfu/8H+y4svvsgPP/xAbGyspW3Xrl0cP36cF198MU3/mJgYhgwZQtWqVXF3d8fDw4MWLVpw4MABS58tW7bw+OOPA9CzZ0/LJfLU82zSpAlVqlRhz549NGrUiHz58lk+l3/PkQwODsbFxSXN+Tdv3pwCBQpw7ty5dJ+riEhmUiIpko2tXbuW0qVLU69evXT1f/nllxk9ejS1atVi6tSpNG7cmJCQELp06ZKm74kTJ+jcuTNPP/00kydPpkCBAvTo0YPDhw8D0LFjR6ZOnQrACy+8wKJFi5g2bVqG4j98+DCtW7cmISGBcePGMXnyZNq2bcv27dsf+L6ffvqJ5s2bc/HiRcaMGcOgQYP49ddfqV+/PqdPn07T/7nnnuPatWuEhITw3HPPERYWxtixY9MdZ8eOHTGZTKxcudLStmTJEipWrEitWrXS9D958iSrV6+mdevWTJkyhaFDh3Lo0CEaN25sSeoCAgIYN24cAK+++iqLFi1i0aJFNGrUyDLO5cuXadGiBTVq1GDatGk0bdr0nvF9/PHHFC5cmODgYJKTkwH47LPP2LBhAzNmzMDPzy/d5yoikqnMIpItXb161QyY27Vrl67++/fvNwPml19+2ap9yJAhZsC8efNmS1vJkiXNgDk8PNzSdvHiRbOzs7N58ODBlrZTp06ZAfOkSZOsxgwODjaXLFkyTQzvvfee+e6vlalTp5oB86VLl+4bd+ox5s+fb2mrUaOG2cfHx3z58mVL24EDB8wODg7m7t27pzler169rMbs0KGDuWDBgvc95t3n4ebmZjabzebOnTubmzVrZjabzebk5GSzr6+veezYsff8DG7dumVOTk5Ocx7Ozs7mcePGWdp27dqV5txSNW7c2AyYZ8+efc99jRs3tmr78ccfzYB5woQJ5pMnT5rd3d3N7du3/89zFBGxJ1UkRbKpuLg4APLnz5+u/t9//z0AgwYNsmofPHgwQJq5lJUqVaJhw4aW14ULF6ZChQqcPHnS5pj/LXVu5bfffktKSkq63nP+/Hn2799Pjx498Pb2trRXq1aNp59+2nKed3v99detXjds2JDLly9bPsP0ePHFF9myZQvR0dFs3ryZ6Ojoe17WhjvzKh0c7nx9Jicnc/nyZctl+71796b7mM7OzvTs2TNdfZ955hlee+01xo0bR8eOHXFxceGzzz5L97FEROxBiaRINuXh4QHAtWvX0tX/r7/+wsHBgbJly1q1+/r64uXlxV9//WXVXqJEiTRjFChQgCtXrtgYcVrPP/889evX5+WXX6ZIkSJ06dKFZcuWPTCpTI2zQoUKafYFBATwzz//cP36dav2f59LgQIFADJ0Li1btiR//vx8/fXXLF68mMcffzzNZ5kqJSWFqVOnUq5cOZydnSlUqBCFCxfm4MGDXL16Nd3HLFasWIZurPnoo4/w9vZm//79TJ8+HR8fn3S/V0TEHpRIimRTHh4e+Pn58fvvv2foff++2eV+8uTJc892s9ls8zFS5++lcnV1JTw8nJ9++olu3bpx8OBBnn/+eZ5++uk0fR/Gw5xLKmdnZzp27MiCBQtYtWrVfauRAO+//z6DBg2iUaNGfPnll/z4449s3LiRypUrp7vyCnc+n4zYt28fFy9eBODQoUMZeq+IiD0okRTJxlq3bs2ff/5JRETEf/YtWbIkKSkpHD9+3Kr9woULxMbGWu7AzgwFChSwusM51b+rngAODg40a9aMKVOmcOTIESZOnMjmzZv5+eef7zl2apzHjh1Ls+/o0aMUKlQINze3hzuB+3jxxRfZt28f165du+cNSqmWL19O06ZN+eKLL+jSpQvPPPMMQUFBaT6T9Cb16XH9+nV69uxJpUqVePXVVwkNDWXXrl2ZNr6IiC2USIpkY8OGDcPNzY2XX36ZCxcupNn/559/8vHHHwN3Ls0Cae6snjJlCgCtWrXKtLjKlCnD1atXOXjwoKXt/PnzrFq1yqpfTExMmvemLsz97yWJUhUtWpQaNWqwYMECq8Ts999/Z8OGDZbztIemTZsyfvx4Zs6cia+v73375cmTJ02185tvvuHs2bNWbakJ772S7owaPnw4UVFRLFiwgClTplCqVCmCg4Pv+zmKiGQFLUguko2VKVOGJUuW8PzzzxMQEGD1ZJtff/2Vb775hh49egBQvXp1goODmTNnDrGxsTRu3JidO3eyYMEC2rdvf9+lZWzRpUsXhg8fTocOHejfvz83btxg1qxZlC9f3upmk3HjxhEeHk6rVq0oWbIkFy9e5NNPP+Wxxx6jQYMG9x1/0qRJtGjRgsDAQHr37s3NmzeZMWMGnp6ejBkzJtPO498cHBx49913/7Nf69atGTduHD179qRevXocOnSIxYsXU7p0aat+ZcqUwcvLi9mzZ5M/f37c3NyoW7cu/v7+GYpr8+bNfPrpp7z33nuW5Yjmz59PkyZNGDVqFKGhoRkaT0Qks6giKZLNtW3bloMHD9K5c2e+/fZb+vbty4gRIzh9+jSTJ09m+vTplr5z585l7Nix7Nq1iwEDBrB582ZGjhzJ0qVLMzWmggULsmrVKvLly8ewYcNYsGABISEhtGnTJk3sJUqUYN68efTt25dPPvmERo0asXnzZjw9Pe87flBQEOvXr6dgwYKMHj2ajz76iCeffJLt27dnOAmzh7fffpvBgwfz448/8tZbb7F3716+++47ihcvbtUvb968LFiwgDx58vD666/zwgsvsHXr1gwd69q1a/Tq1YuaNWvyzjvvWNobNmzIW2+9xeTJk/ntt98y5bxERDLKZM7IbHQRERERkf9PFUkRERERsYkSSRERERGxiRJJEREREbGJEkkRERERsYkSSRERERGxiRJJEREREbGJEkkRERERsUmOfLJ
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 800x600 with 2 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Смещение: 0.852988221106613\n",
|
|||
|
"Дисперсия: 0.006548654676149887\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Конвейер для логистической регрессии\n",
|
|||
|
"pipeline_logreg = Pipeline([\n",
|
|||
|
" ('preprocessing', preprocessing),\n",
|
|||
|
" ('classifier', LogisticRegression())\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Определение сетки гиперпараметров (возможных знач-ий гиперпараметров) для перебора\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" # Параметр регуляризации (сила регуляризации), чем меньше, тем сильнее регуляризация\n",
|
|||
|
" 'classifier__C': [0.1, 0.5, 1],\n",
|
|||
|
" # Тип регуляризации (ф-ия штрафов)\n",
|
|||
|
" 'classifier__penalty': ['l1', 'l2'],\n",
|
|||
|
" # Решатель (сам алгоритм?)\n",
|
|||
|
" 'classifier__solver': ['liblinear', 'saga']\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Создание объекта GridSearchCV для поиска лучших гиперпараметров по сетке с максимальным знач-ием ROC-кривой\n",
|
|||
|
"grid_search = GridSearchCV(pipeline_logreg, param_grid, cv=5, scoring='accuracy', n_jobs=-1)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели с перебором гиперпараметров\n",
|
|||
|
"grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Лучшие гиперпараметры: \", grid_search.best_params_)\n",
|
|||
|
"\n",
|
|||
|
"# Лучшая модель логистической регрессии\n",
|
|||
|
"best_model = grid_search.best_estimator_\n",
|
|||
|
"\n",
|
|||
|
"# Использование и оценка лучшей логистической модели\n",
|
|||
|
"y_pred_proba = best_model.predict_proba(X_test)[:, 1]\n",
|
|||
|
"print(f'ROC у логистической регрессии = {roc_auc_score(y_test, y_pred_proba)}')\n",
|
|||
|
"\n",
|
|||
|
"y_pred = best_model.predict(X_test)\n",
|
|||
|
"print(f'Точность = {accuracy_score(y_test, y_pred)}')\n",
|
|||
|
"\n",
|
|||
|
"fpr, tpr, _ = metrics.roc_curve(y_test, y_pred_proba)\n",
|
|||
|
"\n",
|
|||
|
"# построение ROC кривой\n",
|
|||
|
"plt.plot(fpr, tpr)\n",
|
|||
|
"plt.ylabel('True Positive Rate')\n",
|
|||
|
"plt.xlabel('False Positive Rate')\n",
|
|||
|
"plt.show()\n",
|
|||
|
"\n",
|
|||
|
"# Построение матрицы ошибок\n",
|
|||
|
"conf_matrix = confusion_matrix(y_test, y_pred)\n",
|
|||
|
"\n",
|
|||
|
"# Визуализация матрицы ошибок\n",
|
|||
|
"plt.figure(figsize=(8, 6))\n",
|
|||
|
"sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', \n",
|
|||
|
" xticklabels=['Предсказанный \"безопасный\"', 'Предсказанный \"опасный\"'], \n",
|
|||
|
" yticklabels=['Действительно \"безопасный\"', 'Действительно \"опасный\"'])\n",
|
|||
|
"plt.title('Confusion Matrix')\n",
|
|||
|
"plt.ylabel('Actual')\n",
|
|||
|
"plt.xlabel('Predicted')\n",
|
|||
|
"plt.show()\n",
|
|||
|
"\n",
|
|||
|
"# Оценка дисперсии и смещения\n",
|
|||
|
"cv_results = grid_search.cv_results_\n",
|
|||
|
"mean_test_score = cv_results['mean_test_score']\n",
|
|||
|
"std_test_score = cv_results['std_test_score']\n",
|
|||
|
"\n",
|
|||
|
"print(f\"Смещение: {mean_test_score.mean()}\")\n",
|
|||
|
"print(f\"Дисперсия: {std_test_score.mean()}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Метод случаного леса (набор деревьев решений)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Лучшие гиперпараметры: {'classifier__max_depth': 20, 'classifier__min_samples_leaf': 4, 'classifier__n_estimators': 200}\n",
|
|||
|
"ROC у метода случайного леса = 0.9081081989462431\n",
|
|||
|
"Точность = 0.8718891402714932\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGwCAYAAABVdURTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABDXElEQVR4nO3df5xM9eLH8ffs2p1d7C7a7C+rRYj8/tmSXNqiJPqpctn061aoSyqETfl1+yHdKKUkfetSXZUbURRFSlkrRSu/8nOXDbvs2l8z5/uHHG0WO2tmzs7s6/l47ON+ztlzZt5z6M7bmTPnYzMMwxAAAICfCLA6AAAAgDtRbgAAgF+h3AAAAL9CuQEAAH6FcgMAAPwK5QYAAPgVyg0AAPArVawO4G1Op1P79u1TWFiYbDab1XEAAEAZGIaho0ePKjY2VgEBZz83U+nKzb59+xQfH291DAAAUA67d+9WnTp1zrpNpSs3YWFhkk4cnPDwcIvTAACAssjJyVF8fLz5Pn42la7cnPwoKjw8nHIDAICPKcslJVxQDAAA/ArlBgAA+BXKDQAA8CuUGwAA4FcoNwAAwK9QbgAAgF+h3AAAAL9CuQEAAH6FcgMAAPwK5QYAAPgVS8vNV199pd69eys2NlY2m00fffTROfdZsWKF2rRpI7vdrosvvlhz5szxeE4AAOA7LC03ubm5atmypWbMmFGm7Xfs2KFevXqpW7duSktL0z//+U/dc889Wrp0qYeTAgAAX2HpxJnXXHONrrnmmjJvP3PmTNWrV0/PP/+8JKlJkyZatWqVXnjhBfXo0cNTMYFyKSh26FBuoRxOw+ooAOBVwVUCVDssxLLn96lZwdesWaOkpKQS63r06KF//vOfZ9ynoKBABQUF5nJOTo6n4sHP5Bc5tGH3EW09eEzBgQEyDGnCok3KyS9WeMjZ/9MxDOloQbGXkgJAxdKmbg0teLCzZc/vU+UmIyNDUVFRJdZFRUUpJydHx48fV2ho6Gn7TJ48WePHj/dWRPiB77b/rpdXbNPKLQfPuE1OftmKS5UAmwIDbO6KBgA+ISjQ2u8r+VS5KY9Ro0Zp+PDh5nJOTo7i4+MtTISKbNmmTN0z94fT1teLrKaEC6qq2Glo96E8Tb6xhaLC7Wd9rJpVg1WjapBsNsoNAHiTT5Wb6OhoZWZmlliXmZmp8PDwUs/aSJLdbpfdfvY3IeDBd9Zp8caMEut6tYhR/w511eniSItSAQDKw6fKTWJiohYvXlxi3eeff67ExESLEsHXvfPdb3riw59OWz/jjjbq1SLGgkQAgPNlabk5duyYtm7dai7v2LFDaWlpqlWrlurWratRo0Zp7969mjt3riTp/vvv1/Tp0/XYY4/prrvu0hdffKH33ntPixYtsuolwIcNeTdVn/y4v8S6ZcOvUL3I6lwnAwA+zNJy88MPP6hbt27m8slrY5KTkzVnzhzt379fu3btMn9fr149LVq0SMOGDdOLL76oOnXq6PXXX+dr4HDZh+v3lCg2g7s10CNXNVYApQYAfJ7NMIxKdROOnJwcRUREKDs7W+Hh4VbHgQXSdh9R3xmrzeXUsVepVrVgCxMBAM7Flfdvn7rmBjhfVzzzpXYdyjOXn7m5BcUGAPwME2ei0nj+s/QSxea29vG6tR23BQAAf8OZG/i9/CKHHvvgRy3csM9ct/mpngoNDrQwFQDAUyg38GsOp6EbXv5Gm/efmnbj7bs7UGwAwI9RbuC3DuTkq8Ok5SXWLX6oi5rGciE5APgzrrmBX1q9Neu0YvPuPR0pNgBQCXDmBn5pyLup5viKRhfqrUHtmeMJACoJyg38yg87D+nmmWvM5b9fVlcT+ja3MBEAwNv4WAp+49ON+0sUG0kacXVji9IAAKzCmRv4jW+2/W6O7+1ST6OvbcJHUQBQCVFu4BcMw9Db3/4mSUpOvEhP9GpqcSIAgFX4WAp+Yefvp+483OaimhYmAQBYjXIDv9D9+RXm+PqWsdYFAQBYjnIDn7cjK1cn57ZvUSeC62wAoJKj3MCnpe0+om7PrTCX5wzqYF0YAECFQLmBz8o+XqS+M1abyz0ujVKtasEWJgIAVAR8Wwo+yTAMtRz/mbk8qHOCUnpfamEiAEBFwZkb+KR6oxab4+r2KhQbAICJcgOfM/K/P5ZYTht3lUVJAAAVER9LwacYhqF53+82l7dPulYBAXw7CgBwCmdu4FN++9PN+r54pCvFBgBwGsoNfMaRvEJdPe0rc7n+hdUtTAMAqKj4WAo+YclPGbr//9aZy8GB9HIAQOl4h0CFdySvsESxqRocqOWPdLUwEQCgIuPMDSq07ONFavXU5+byhL7N9PfLLrIwEQCgouPMDSq0l5b/ao7bXVRT/TvWtTANAMAXcOYGFVqVP11b88EDnSxMAgDwFZy5QYW29OcMSdLdl9ezOAkAwFdQblBhFTuc2pGVK0k6ll9scRoAgK+g3KDCmrbs1PU2N7aJszAJAMCXUG5QIRmGoelfbjWX2yXUsjANAMCXUG5QIf1rSbo5HnXNJQpkmgUAQBlRblDhOJ2GZq7cZi4P6szFxACAsqPcoMJZtHG/OZ5xRxsFV+GvKQCg7LjPDSqU0R9u1Lvf7TKXr20ebWEaAIAv4p/EqDBWbjlYotg8e3ML2WxcawMAcA1nblBhjPloozleMeJvSoisZmEaAICv4swNKoyso4WSpBtbx1FsAADlRrlBhZCdV6TjRQ5J0tWXRlmcBgDgyyg3qBBGLvjRHHdpeKGFSQAAvo5yA8vlFzn06U8nJsisEmBTNTuXggEAyo9yA8v1fmmVOZ6V3M7CJAAAf0C5gaV2/Z6nXw8ckySFhVRRt8a1LU4EAPB1lBtYatryLeZ41ePdLUwCAPAXlBtYKm33EXMcERpkXRAAgN+g3MAyR/OLtP1griTp0R6NLU4DAPAXlBtY5sDRAnN8Y5s4C5MAAPwJ5QaWWZC6R5JUo2qQYiJCLU4DAPAXlBtYwuk0NOPLbZKkoED+GgIA3Id3FVhiwOzvzPHgvzWwMAkAwN9QbuB1eYXFWr31d3P5zs71LEwDAPA3lBt43cPz0szxzL+3tS4IAMAvUW7gdZk5+eY4qQl3JAYAuBflBl5nr3Lir920fq1UhYuJAQBuxjsLvO6H3w5LkkKC+OsHAHA/3l3gVdsPHpNhnBjbbDZrwwAA/BLlBl710fq95rhLw0gLkwAA/JXl5WbGjBlKSEhQSEiIOnbsqLVr1551+2nTpqlx48YKDQ1VfHy8hg0bpvz8/LPug4rj319slSRVCw5U1eAqFqcBAPgjS8vN/PnzNXz4cKWkpCg1NVUtW7ZUjx49dODAgVK3f/fddzVy5EilpKRo8+bNeuONNzR//nyNHj3ay8lRHh+s22OO7+/KjfsAAJ5hM4yTV0B4X8eOHdW+fXtNnz5dkuR0OhUfH6+hQ4dq5MiRp20/ZMgQbd68WcuXLzfXPfLII/ruu++0atWqUp+joKBABQWnJmjMyclRfHy8srOzFR4e7uZXhLNJGLnIHP868RqmXQAAlFlOTo4iIiLK9P5t2btLYWGh1q1bp6SkpFNhAgKUlJSkNWvWlLpPp06dtG7dOvOjq+3bt2vx4sW69tprz/g8kydPVkREhPkTHx/v3heCMvl2+6k7Eg9LakSxAQB4jGUXPWRlZcnhcCgqKqrE+qioKP3yyy+l7nPHHXcoKytLl19+uQzDUHFxse6///6zfiw1atQoDR8+3Fw+eeYG3nXba9+a4/uuqG9hEgCAv/Opfz6vWLFCkyZN0ssvv6zU1FQtWLBAixYt0tNPP33Gfex2u8LDw0v8wLtmfLnVHPduGavQ4EAL0wAA/J1lZ24iIyMVGBiozMzMEuszMzMVHR1d6j5jx47VgAEDdM8990iSmjdvrtzcXN1333164oknFBDgU12t0nh2abo5ntCnmYVJAACVgWVtIDg4WG3bti1xcbDT6dTy5cuVmJhY6j55eXmnFZjAwBNnASy8LhrnUN1+okM/d0tLRVQNsjgNAMDfWXqjkeHDhys5OVnt2rVThw4dNG3aNOXm5mrQoEGSpIEDByouLk6
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Матрица ошибок:\n",
|
|||
|
"[[1329 397]\n",
|
|||
|
" [ 56 1754]]\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAApIAAAIjCAYAAACwHvu2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB2OElEQVR4nO3dd1gUV9sG8HvovamAEKWIBewtilgjkdgVjUGNYkkswVhQVBILVhLs2EsUNRpj7JpEJViIBlFRrEiswQYYEQGVuvv94ct+bkCFleUge/+85rp2zzkz88yiy+MzM2ckuVwuBxERERFRMWmJDoCIiIiI3k9MJImIiIhIJUwkiYiIiEglTCSJiIiISCVMJImIiIhIJUwkiYiIiEglTCSJiIiISCVMJImIiIhIJUwkiYiIiEglTCSJ6I2uX7+ODh06wNzcHJIkYc+ePSW6/Tt37kCSJISFhZXodt9nbdu2Rdu2bUWHQUT0Vkwkid4DN2/exPDhw+Hs7AwDAwOYmZnBw8MDS5YswYsXL9S6b19fX1y6dAlz5szB5s2b0aRJE7XurzQNGjQIkiTBzMys0M/x+vXrkCQJkiRh/vz5xd7+gwcPEBQUhNjY2BKIloio7NERHQARvdmvv/6KTz/9FPr6+hg4cCDq1KmD7OxsnDhxAgEBAbhy5QrWrFmjln2/ePECUVFR+PbbbzFq1Ci17MPBwQEvXryArq6uWrb/Njo6Onj+/Dn279+PPn36KPVt2bIFBgYGyMzMVGnbDx48wIwZM+Do6IgGDRoUeb3Dhw+rtD8iotLGRJKoDLt9+zZ8fHzg4OCAI0eOoHLlyoo+Pz8/3LhxA7/++qva9v/o0SMAgIWFhdr2IUkSDAwM1Lb9t9HX14eHhwd++umnAonk1q1b0blzZ+zcubNUYnn+/DmMjIygp6dXKvsjInpXPLVNVIaFhIQgIyMDP/zwg1ISmc/FxQVjxoxRvM/NzcWsWbNQrVo16Ovrw9HREd988w2ysrKU1nN0dESXLl1w4sQJfPjhhzAwMICzszM2bdqkGBMUFAQHBwcAQEBAACRJgqOjI4CXp4TzX78qKCgIkiQptYWHh6Nly5awsLCAiYkJatasiW+++UbR/7prJI8cOYJWrVrB2NgYFhYW6N69O+Li4grd340bNzBo0CBYWFjA3NwcgwcPxvPnz1//wf5Hv3798PvvvyM1NVXRdubMGVy/fh39+vUrMD4lJQUTJkxA3bp1YWJiAjMzM3Ts2BEXLlxQjDl27BiaNm0KABg8eLDiFHn+cbZt2xZ16tRBTEwMWrduDSMjI8Xn8t9rJH19fWFgYFDg+L28vGBpaYkHDx4U+ViJiEoSE0miMmz//v1wdnZGixYtijT+iy++wLRp09CoUSMsWrQIbdq0QXBwMHx8fAqMvXHjBnr37o2PP/4YCxYsgKWlJQYNGoQrV64AALy9vbFo0SIAQN++fbF582YsXry4WPFfuXIFXbp0QVZWFmbOnIkFCxagW7duOHny5BvX++OPP+Dl5YXk5GQEBQXB398ff/31Fzw8PHDnzp0C4/v06YP09HQEBwejT58+CAsLw4wZM4ocp7e3NyRJwq5duxRtW7duRa1atdCoUaMC42/duoU9e/agS5cuWLhwIQICAnDp0iW0adNGkdS5urpi5syZAIBhw4Zh8+bN2Lx5M1q3bq3YzuPHj9GxY0c0aNAAixcvRrt27QqNb8mSJahUqRJ8fX2Rl5cHAFi9ejUOHz6MpUuXws7OrsjHSkRUouREVCY9ffpUDkDevXv3Io2PjY2VA5B/8cUXSu0TJkyQA5AfOXJE0ebg4CAHII+MjFS0JScny/X19eXjx49XtN2+fVsOQD5v3jylbfr6+sodHBwKxDB9+nT5q18rixYtkgOQP3r06LVx5+9jw4YNirYGDRrIra2t5Y8fP1a0XbhwQa6lpSUfOHBggf0NGTJEaZs9e/aUV6hQ4bX7fPU4jI2N5XK5XN67d295+/bt5XK5XJ6Xlye3tbWVz5gxo9DPIDMzU56Xl1fgOPT19eUzZ85UtJ05c6bAseVr06aNHIB81apVhfa1adNGqe3QoUNyAPLZs2fLb926JTcxMZH36NHjrcdIRKROrEgSlVFpaWkAAFNT0yKN/+233wAA/v7+Su3jx48HgALXUrq5uaFVq1aK95UqVULNmjVx69YtlWP+r/xrK/fu3QuZTFakdR4+fIjY2FgMGjQIVlZWivZ69erh448/Vhznq0aMGKH0vlWrVnj8+LHiMyyKfv364dixY0hMTMSRI0eQmJhY6Glt4OV1lVpaL78+8/Ly8PjxY8Vp+3PnzhV5n/r6+hg8eHCRxnbo0AHDhw/HzJkz4e3tDQMDA6xevbrI+yIiUgcmkkRllJmZGQAgPT29SOP/+ecfaGlpwcXFRand1tYWFhYW+Oeff5Taq1atWmAblpaWePLkiYoRF/TZZ5/Bw8MDX3zxBWxsbODj44Pt27e/ManMj7NmzZoF+lxdXfHvv//i2bNnSu3/PRZLS0sAKNaxdOrUCaampvj555+xZcsWNG3atMBnmU8mk2HRokWoXr069PX1UbFiRVSqVAkXL17E06dPi7xPe3v7Yt1YM3/+fFhZWSE2NhahoaGwtrYu8rpEROrARJKojDIzM4OdnR0uX75crPX+e7PL62hraxfaLpfLVd5H/vV7+QwNDREZGYk//vgDAwYMwMWLF/HZZ5/h448/LjD2XbzLseTT19eHt7c3Nm7ciN27d7+2GgkAc+fOhb+/P1q3bo0ff/wRhw4dQnh4OGrXrl3kyivw8vMpjvPnzyM5ORkAcOnSpWKtS0SkDkwkicqwLl264ObNm4iKinrrWAcHB8hkMly/fl2pPSkpCampqYo7sEuCpaWl0h3O+f5b9QQALS0ttG/fHgsXLsTVq1cxZ84cHDlyBEePHi102/lxxsfHF+i7du0aKlasCGNj43c7gNfo168fzp8/j/T09EJvUMq3Y8cOtGvXDj/88AN8fHzQoUMHeHp6FvhMiprUF8WzZ88wePBguLm5YdiwYQgJCcGZM2dKbPtERKpgIklUhk2cOBHGxsb44osvkJSUVKD/5s2bWLJkCYCXp2YBFLizeuHChQCAzp07l1hc1apVw9OnT3Hx4kVF28OHD7F7926lcSkpKQXWzZ+Y+79TEuWrXLkyGjRogI0bNyolZpcvX8bhw4cVx6kO7dq1w6xZs7Bs2TLY2tq+dpy2tnaBaucvv/yC+/fvK7XlJ7yFJd3FNWnSJCQkJGDjxo1YuHAhHB0d4evr+9rPkYioNHBCcqIyrFq1ati6dSs+++wzuLq6Kj3Z5q+//sIvv/yCQYMGAQDq168PX19frFmzBqmpqWjTpg1Onz6NjRs3okePHq+dWkYVPj4+mDRpEnr27InRo0fj+fPnWLlyJWrUqKF0s8nMmTMRGRmJzp07w8HBAcnJyVixYgU++OADtGzZ8rXbnzdvHjp27Ah3d3cMHToUL168wNKlS2Fubo6goKASO47/0tLSwpQpU946rkuXLpg5cyYGDx6MFi1a4NKlS9iyZQucnZ2VxlWrVg0WFhZYtWoVTE1NYWxsjGbNmsHJyalYcR05cgQrVqzA9OnTFdMRbdiwAW3btsXUqVMREhJSrO0REZUUViSJyrhu3brh4sWL6N27N/bu3Qs/Pz9MnjwZd+7cwYIFCxAaGqoYu27dOsyYMQNnzpzB2LFjceTIEQQGBmLbtm0lGlOFChWwe/duGBkZYeLEidi4cSOCg4PRtWvXArFXrVoV69evh5+fH5YvX47WrVvjyJEjMDc3f+32PT09cfDgQVSoUAHTpk3D/Pnz0bx5c5w8ebLYSZg6fPPNNxg/fjwOHTqEMWPG4Ny5c/j1119RpUoVpXG6urrYuHEjtLW1MWLECPTt2xfHjx8v1r7S09MxZMgQNGzYEN9++62ivVWrVhgzZgwWLFiAU6dOlchxEREVlyQvztXoRERERET/w4okEREREamEiSQRERERqYSJJBERERGphIkkEREREamEiSQRERERqYSJJBERERGphIkkEREREamkXD7Zxtn
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 800x600 with 2 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Смещение: 0.8686998314031272\n",
|
|||
|
"Дисперсия: 0.003095104102985812\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Конвейер для случайного леса\n",
|
|||
|
"pipeline_ranfor = Pipeline([\n",
|
|||
|
" ('preprocessing', preprocessing),\n",
|
|||
|
" ('classifier', RandomForestClassifier())\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Определение сетки гиперпараметров\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" #Количество деревьев в лесу\n",
|
|||
|
" 'classifier__n_estimators': [50, 100, 200],\n",
|
|||
|
" #Максимальная глубина дерева\n",
|
|||
|
" 'classifier__max_depth': [10, 20, 30],\n",
|
|||
|
" #Минимальное количество образцов для листового узла\n",
|
|||
|
" 'classifier__min_samples_leaf': [1, 2, 4]\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Создание объекта GridSearchCV\n",
|
|||
|
"grid_search = GridSearchCV(pipeline_ranfor, param_grid, cv=5, scoring='accuracy', n_jobs=-1)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели с перебором гиперпараметров\n",
|
|||
|
"grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Лучшие гиперпараметры: \", grid_search.best_params_)\n",
|
|||
|
"\n",
|
|||
|
"# Лучшая модель случайного леса\n",
|
|||
|
"best_model = grid_search.best_estimator_\n",
|
|||
|
"\n",
|
|||
|
"# Использование и оценка лучшей модели\n",
|
|||
|
"y_pred_proba = best_model.predict_proba(X_test)[:, 1]\n",
|
|||
|
"print(f'ROC у метода случайного леса = {roc_auc_score(y_test, y_pred_proba)}')\n",
|
|||
|
"\n",
|
|||
|
"y_pred = best_model.predict(X_test)\n",
|
|||
|
"print(f'Точность = {accuracy_score(y_test, y_pred)}')\n",
|
|||
|
"\n",
|
|||
|
"fpr, tpr, _ = metrics.roc_curve(y_test, y_pred_proba)\n",
|
|||
|
"\n",
|
|||
|
"# построение ROC кривой\n",
|
|||
|
"plt.plot(fpr, tpr)\n",
|
|||
|
"plt.ylabel('True Positive Rate')\n",
|
|||
|
"plt.xlabel('False Positive Rate')\n",
|
|||
|
"plt.show()\n",
|
|||
|
"\n",
|
|||
|
"# Построение матрицы ошибок\n",
|
|||
|
"conf_matrix = confusion_matrix(y_test, y_pred)\n",
|
|||
|
"\n",
|
|||
|
"# Визуализация матрицы ошибок\n",
|
|||
|
"plt.figure(figsize=(8, 6))\n",
|
|||
|
"sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', \n",
|
|||
|
" xticklabels=['Предсказанный \"безопасный\"', 'Предсказанный \"опасный\"'], \n",
|
|||
|
" yticklabels=['Действительно \"безопасный\"', 'Действительно \"опасный\"'])\n",
|
|||
|
"plt.title('Confusion Matrix')\n",
|
|||
|
"plt.ylabel('Actual')\n",
|
|||
|
"plt.xlabel('Predicted')\n",
|
|||
|
"plt.show()\n",
|
|||
|
"\n",
|
|||
|
"# Оценка дисперсии и смещения\n",
|
|||
|
"cv_results = grid_search.cv_results_\n",
|
|||
|
"mean_test_score = cv_results['mean_test_score']\n",
|
|||
|
"std_test_score = cv_results['std_test_score']\n",
|
|||
|
"\n",
|
|||
|
"print(f\"Смещение: {mean_test_score.mean()}\")\n",
|
|||
|
"print(f\"Дисперсия: {std_test_score.mean()}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Градиентный бустинг"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Лучшие гиперпараметры: {'classifier__learning_rate': 0.1, 'classifier__max_depth': 3, 'classifier__n_estimators': 300, 'classifier__subsample': 0.5}\n",
|
|||
|
"ROC у метода градиентного спуска = 0.9012421336337971\n",
|
|||
|
"Точность = 0.872737556561086\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGwCAYAAABVdURTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABC8klEQVR4nO3dfXzO9eLH8fe12a4N29Da5maam9zlZm7iIIlWVEecTuWUg9TRKTd12lEhLCWcinSinJRUv4rqqJxoDoqDdMisFCY3i7CxsDtsdl3f3x/LVzuGXXNd1/e6rr2ej8f1eHy+332/1/XeF13vvrc2wzAMAQAABIggqwMAAAC4E+UGAAAEFMoNAAAIKJQbAAAQUCg3AAAgoFBuAABAQKHcAACAgFLN6gDe5nQ6dfDgQUVERMhms1kdBwAAVIBhGMrPz1e9evUUFHThfTNVrtwcPHhQ8fHxVscAAACVsH//fjVo0OCCy1S5chMRESGpdONERkZanAYAAFREXl6e4uPjze/xC6ly5ebMoajIyEjKDQAAfqYip5RwQjEAAAgolBsAABBQKDcAACCgUG4AAEBAodwAAICAQrkBAAABhXIDAAACCuUGAAAEFMoNAAAIKJQbAAAQUCwtN//5z3/Ur18/1atXTzabTR9//PFF11m9erU6dOggu92upk2basGCBR7PCQAA/Iel5aawsFDt2rXTnDlzKrT83r17dcstt6hXr15KT0/XX/7yF/3pT3/S8uXLPZwUAAD4C0sfnHnTTTfppptuqvDyc+fOVaNGjTRjxgxJUsuWLbVu3Tq98MIL6tOnj6diAuU6nH9KxSVOq2MAgM8JrRakmIgwyz7fr54KvmHDBiUlJZWZ16dPH/3lL3857zpFRUUqKioyp/Py8jwVDwHktMOpBesz9dl3h1TiNBQeElzm5//de9SiZADg+zo0rKXFI7pb9vl+VW6ysrIUGxtbZl5sbKzy8vJ08uRJhYeHn7POtGnTNHnyZG9FhJ/bkZWnvrPWurSOvRrn5QPAr4UEW/vfRb8qN5Uxbtw4JScnm9N5eXmKj4+3MBF80f6jJ9Tj2S/K/dldnRuqQ8NaCg8tu/cmNjJMVyfU8UY8AIAL/KrcxMXFKTs7u8y87OxsRUZGlrvXRpLsdrvsdrs34sEPFZc41ePZz5WdV1Rmfv1a4fr3I9eqht2v/okAAORn5aZr165atmxZmXkrVqxQ165dLUoEf3fDC2vKFJuklrGaN6SjbDabhakAAJfC0nJTUFCgXbt2mdN79+5Venq66tSpo4YNG2rcuHE6cOCA3nrrLUnSAw88oNmzZ+uxxx7Tvffeq88//1zvv/++li5datWvAD/2RcZh/fjzCXM6Y0pf2asFX2ANAIA/sPSMn6+//lrt27dX+/btJUnJyclq3769Jk2aJEk6dOiQ9u3bZy7fqFEjLV26VCtWrFC7du00Y8YMvfbaa1wGDpdt/vGYhr2xyZz+bnIfig0ABAibYRiG1SG8KS8vT1FRUcrNzVVkZKTVceBlmTmFuu751WXmvXVvZ13b7HJrAgEAKsSV72+uYUWV8c3+4+cUm+E9GlFsACDA+NUJxcCl6D9nvTm+ql6kFo/oxqEoAAhAlBtUCb9+TMJNreP0yh87WpgGAOBJHJZCwDtZ7NCAX+21mX5bWwvTAAA8jT03CGh3zt2gjZllnwMVVT3EojQAAG9gzw0CVkFRyTnF5pOR1j3IDQDgHey5QcC691f3sfkm5UZFhbPHBgCqAvbcICDtPlJQZq8NxQYAqg7KDQJOUYlD189YY06n/qWHhWkAAN5GuUHAmf352eeVtakfpRZx3IkaAKoSyg0CyoL1e/XSr8rNklGcQAwAVQ0nFCMg7D96Qj2e/aLMvGd/31Y2m82iRAAAq7DnBgEhff/xMtPjbmqhOzo1sCYMAMBS7LlBQHjsw28lSS3iIpT6l2stTgMAsBJ7buD3Hnh7s06edkiS6kaFWZwGAGA1yg382pe7cpT6fZY5PXcwD8QEgKqOcgO/dbSwWHe/9l9z+t+PXCt7tWALEwEAfAHlBn6rw9MrzHFKv1ZqFhthYRoAgK+g3MDvNb68hoZ1b2R1DACAj6DcwO/sP3pCCWOXmtMf/LmrhWkAAL6GcgO/k/x+epnpOjVCrQkCAPBJlBv4leISpzZlHjOndzzdl7sQAwDKoNzAr+SePG2O1z7WS2EhXB0FACiLcgO/8t2BXHMcX6e6hUkAAL6KcgO/MmzBJqsjAAB8HOUGfuWyX04eHtY9wdogAACfRbmB3ygoKtHPhcWSpP6J9S1OAwDwVZQb+I2l3x40xy3iuBsxAKB8lBv4jaOFpVdKBdnEVVIAgPOi3MAv5J06rb+l7pAkdW8abXEaAIAvo9zA5x3KPam2T/7bnL7xqjgL0wAAfB3lBj6v67TPzXHnhDoa/JsrLEwDAPB1lBv4tLc3ZJrjGqHBeu/+31gXBgDgF6pZHQA4n9fW7tGUpdvN6a8n3KDgIJ4jBQC4MPbcwCcVFpWUKTbvDf+NwkO5QgoAcHHsuYFP+ve2LHO86P7fqEvjyyxMAwDwJ+y5gU965pe9NmEhQRQbAIBLKDfwOUUlDuUUlD5mIeGyGhanAQD4G8oNfM7/fbXPHL/yx44WJgEA+CPKDXzO059uM8eNotlzAwBwDeUGPmXLvmPmeHTvphYmAQD4K8oNfMrvXv7SHI/sRbkBALiOcgOfkfjU2edHPdqnOU/+BgBUCuUGPuFYYbGOnzhtTo+4romFaQAA/oxyA5/Q/ukV5viHZ26SzcZjFgAAlUO5geUWbjx76XdifC2FBPPXEgBQeXyLwHKT/3X20u+PRnSzMAkAIBBQbmC5OjVCJUkP9GzC4SgAwCWj3MBSye+n68Dxk5Kk3i1iLE4DAAgElBtY5ouMw1qcdsCcbh4bYWEaAECgoNzAMsPe2GSONz5xvaKqh1iYBgAQKCg3sMTJYoc5vqtzQ8VEhFmYBgAQSCg3sMQTH209O76lpYVJAACBhnIDS2zY87Mkqaa9mmraq1mcBgAQSCg38Lqfjp3QodxTkqSPR3a3OA0AINBQbuB11/ztC3PcOLqGhUkAAIGIcgOvOu1wmuPOjeooKIib9gEA3MvycjNnzhwlJCQoLCxMXbp00caNGy+4/KxZs9S8eXOFh4crPj5ejzzyiE6dOuWltLhUa384Yo5fG9rJwiQAgEBlablZtGiRkpOTlZKSorS0NLVr1059+vTR4cOHy13+3Xff1dixY5WSkqLt27fr9ddf16JFizR+/HgvJ0dlFJU4dO+Cr83pGqGcSAwAcD9Ly83MmTM1fPhwDRs2TK1atdLcuXNVvXp1zZ8/v9zlv/zyS3Xv3l133323EhISdOONN+quu+664N6eoqIi5eXllXnBGuP+efby78f7tlAwh6QAAB5gWbkpLi7W5s2blZSUdDZMUJCSkpK0YcOGctfp1q2bNm/ebJaZPXv2aNmyZbr55pvP+znTpk1TVFSU+YqPj3fvLwKXRYWH6IGeja2OAQAIUJYdF8jJyZHD4VBsbGyZ+bGxsdqxY0e569x9993KycnRNddcI8MwVFJSogceeOCCh6XGjRun5ORkczovL4+CY4GDx09q8ZbS50gN7XoFT/8GAHiM5ScUu2L16tWaOnWqXn75ZaWlpWnx4sVaunSpnn766fOuY7fbFRkZWeYF7+s2/XNzfHWjOhYmAQAEOsv23ERHRys4OFjZ2dll5mdnZysuLq7cdSZOnKjBgwfrT3/6kySpTZs2Kiws1P33368nnnhCQUF+1dWqlMsj7DqSX6QWcRHqceXlVscBAAQwy9pAaGioOnbsqFWrVpnznE6nVq1apa5du5a7zokTJ84pMMHBwZIkwzA8FxaX5JP0AzqSXyRJmnlnorVhAAABz9JrcZOTkzV06FB16tRJnTt31qxZs1RYWKhhw4ZJkoYMGaL69etr2rRpkqR+/fpp5syZat++vbp06aJdu3Zp4sSJ6tevn1ly4FtOFJfo4YXp5nR8nXDrwgAAqgRLy83AgQN15MgRTZo0SVlZWUpMTFRqaqp5kvG+ffvK7Km
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Матрица ошибок:\n",
|
|||
|
"[[1326 400]\n",
|
|||
|
" [ 50 1760]]\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAApIAAAIjCAYAAACwHvu2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB2r0lEQVR4nO3dd1gUV9sG8HtpS19ABcQogljAiiWKWCOR2BUTRY1iSYwGK4pKYm8k2BC7sWCNMRZiiUaiURQRFUWNIrEGG2BEBERpu98ffuzrigVWloPs/cs118WeOTPz7BLXx+ecOSNRKBQKEBEREREVkY7oAIiIiIjow8REkoiIiIjUwkSSiIiIiNTCRJKIiIiI1MJEkoiIiIjUwkSSiIiIiNTCRJKIiIiI1MJEkoiIiIjUwkSSiIiIiNTCRJKI3uratWto3749ZDIZJBIJwsLCivX8t2/fhkQiQWhoaLGe90PWpk0btGnTRnQYRETvxESS6ANw48YNfPPNN3B0dIShoSHMzc3h7u6OxYsX49mzZxq9to+PDy5duoQ5c+Zg06ZNaNy4sUavV5IGDhwIiUQCc3Pz136O165dg0QigUQiwfz584t8/vv372P69OmIjY0thmiJiEofPdEBENHb7d+/H1988QWkUikGDBiAOnXqIDs7GydOnIC/vz8uX76M1atXa+Taz549Q1RUFL7//nuMGDFCI9ewt7fHs2fPoK+vr5Hzv4uenh4yMzOxd+9e9OrVS2Xfli1bYGhoiOfPn6t17vv372PGjBmoWrUqGjRoUOjjDh06pNb1iIhKGhNJolLs1q1b8Pb2hr29PY4cOYKKFSsq9/n6+uL69evYv3+/xq7/8OFDAICFhYXGriGRSGBoaKix87+LVCqFu7s7fv755wKJ5NatW9GpUyfs3LmzRGLJzMyEsbExDAwMSuR6RETvi0PbRKVYUFAQMjIysHbtWpUkMp+TkxNGjx6tfJ2bm4tZs2ahWrVqkEqlqFq1Kr777jtkZWWpHFe1alV07twZJ06cwMcffwxDQ0M4Ojpi48aNyj7Tp0+Hvb09AMDf3x8SiQRVq1YF8GJIOP/nl02fPh0SiUSlLTw8HC1atICFhQVMTU1Rs2ZNfPfdd8r9b5ojeeTIEbRs2RImJiawsLBAt27dEBcX99rrXb9+HQMHDoSFhQVkMhkGDRqEzMzMN3+wr+jbty8OHDiA1NRUZduZM2dw7do19O3bt0D/lJQUjB8/HnXr1oWpqSnMzc3RoUMHXLhwQdnn6NGjaNKkCQBg0KBByiHy/PfZpk0b1KlTBzExMWjVqhWMjY2Vn8urcyR9fHxgaGhY4P17enrC0tIS9+/fL/R7JSIqTkwkiUqxvXv3wtHREc2bNy9U/6+++gpTp05Fw4YNsWjRIrRu3RqBgYHw9vYu0Pf69ev4/PPP8emnn2LBggWwtLTEwIEDcfnyZQCAl5cXFi1aBADo06cPNm3ahODg4CLFf/nyZXTu3BlZWVmYOXMmFixYgK5duyIyMvKtx/3555/w9PREcnIypk+fDj8/P5w8eRLu7u64fft2gf69evVCeno6AgMD0atXL4SGhmLGjBmFjtPLywsSiQS7du1Stm3duhW1atVCw4YNC/S/efMmwsLC0LlzZyxcuBD+/v64dOkSWrdurUzqnJ2dMXPmTADA0KFDsWnTJmzatAmtWrVSnufRo0fo0KEDGjRogODgYLRt2/a18S1evBgVKlSAj48P8vLyAACrVq3CoUOHsGTJEtjZ2RX6vRIRFSsFEZVKT548UQBQdOvWrVD9Y2NjFQAUX331lUr7+PHjFQAUR44cUbbZ29srACgiIiKUbcnJyQqpVKoYN26csu3WrVsKAIp58+apnNPHx0dhb29fIIZp06YpXv5aWbRokQKA4uHDh2+MO/8a69evV7Y1aNBAYW1trXj06JGy7cKFCwodHR3FgAEDClxv8ODBKufs0aOHoly5cm+85svvw8TERKFQKBSff/65ol27dgqFQqHIy8tT2NraKmbMmPHaz+D58+eKvLy8Au9DKpUqZs6cqWw7c+ZMgfeWr3Xr1goAipUrV752X+vWrVXa/vjjDwUAxezZsxU3b95UmJqaKrp37/7O90hEpEmsSBKVUmlpaQAAMzOzQvX//fffAQB+fn4q7ePGjQOAAnMpXVxc0LJlS+XrChUqoGbNmrh586baMb8qf27lb7/9BrlcXqhjHjx4gNjYWAwcOBBWVlbK9nr16uHTTz9Vvs+XDRs2TOV1y5Yt8ejRI+VnWBh9+/bF0aNHkZiYiCNHjiAxMfG1w9rAi3mVOjovvj7z8vLw6NEj5bD9uXPnCn1NqVSKQYMGFapv+/bt8c0332DmzJnw8vKCoaEhVq1aVehrERFpAhNJolLK3NwcAJCenl6o/v/++y90dHTg5OSk0m5rawsLCwv8+++/Ku1VqlQpcA5LS0s8fvxYzYgL6t27N9zd3fHVV1/BxsYG3t7e2L59+1uTyvw4a9asWWCfs7Mz/vvvPzx9+lSl/dX3YmlpCQBFei8dO3aEmZkZfvnlF2zZsgVNmjQp8Fnmk8vlWLRoEapXrw6pVIry5cujQoUKuHjxIp48eVLoa1aqVKlIN9bMnz8fVlZWiI2NRUhICKytrQt9LBGRJjCRJCqlzM3NYWdnh7///rtIx716s8ub6OrqvrZdoVCofY38+Xv5jIyMEBERgT///BP9+/fHxYsX0bt3b3z66acF+r6P93kv+aRSKby8vLBhwwbs3r37jdVIAJg7dy78/PzQqlUrbN68GX/88QfCw8NRu3btQldegRefT1GcP38eycnJAIBLly4V6VgiIk1gIklUinXu3Bk3btxAVFTUO/va29tDLpfj2rVrKu1JSUlITU1V3oFdHCwtLVXucM73atUTAHR0dNCuXTssXLgQV65cwZw5c3DkyBH89ddfrz13fpzx8fEF9l29ehXly5eHiYnJ+72BN+jbty/Onz+P9PT0196glG/Hjh1o27Yt1q5dC29vb7Rv3x4eHh4FPpPCJvWF8fTpUwwaNAguLi4YOnQogoKCcObMmWI7PxGROphIEpViEyZMgImJCb766iskJSUV2H/jxg0sXrwYwIuhWQAF7qxeuHAhAKBTp07FFle1atXw5MkTXLx4Udn24MED7N69W6VfSkpKgWPzF+Z+dUmifBUrVkSDBg2wYcMGlcTs77//xqFDh5TvUxPatm2LWbNmYenSpbC1tX1jP11d3QLVzl9//RX37t1TactPeF+XdBfVxIkTkZCQgA0bNmDhwoWoWrUqfHx83vg5EhGVBC5ITlSKVatWDVu3bkXv3r3h7Oys8mSbkydP4tdff8XAgQMBAPXr14ePjw9Wr16N1NRUtG7dGqdPn8aGDRvQvXv3Ny4tow5vb29MnDgRPXr0wKhRo5CZmYkVK1agRo0aKjebzJw5ExEREejUqRPs7e2RnJyM5cuX46OPPkKLFi3eeP558+ahQ4cOcHNzw5AhQ/Ds2TMsWbIEMpkM06dPL7b38SodHR1Mnjz5nf06d+6MmTNnYtCgQWjevDkuXbqELVu2wNHRUaVftWrVYGFhgZUrV8LMzAwmJiZo2rQpHBwcihTXkSNHsHz5ckybNk25HNH69evRpk0bTJkyBUFBQUU6HxFRcWFFkqiU69q1Ky5evIjPP/8cv/32G3x9fTFp0iTcvn0bCxYsQEhIiLLvmjVrMGPGDJw5cwZjxozBkSNHEBAQgG3bthVrTOXKlcPu3bthbGyMCRMmYMOGDQgMDESXLl0KxF6lShWsW7cOvr6+WLZsGVq1aoUjR45AJpO98fweHh44ePAgypUrh6lTp2L+/Plo1qwZIiMji5yEacJ3332HcePG4Y8//sDo0aNx7tw57N+/H5UrV1bpp6+vjw0bNkBXVxfDhg1Dnz59cOzYsSJdKz09HYMHD4arqyu+//57ZXvLli0xevRoLFiwAKdOnSqW90VEVFQSRVFmoxMRERER/T9WJImIiIhILUwkiYiIiEgtTCSJiIiISC1MJImIiIhILUwkiYiIiEgtTCSJiIiISC1MJImIiIhILWX
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 800x600 with 2 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Смещение: 0.8811650848575816\n",
|
|||
|
"Дисперсия: 0.008658656436943876\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Конвейер\n",
|
|||
|
"pipeline_grad = Pipeline([\n",
|
|||
|
" ('preprocessing', preprocessing),\n",
|
|||
|
" ('classifier', GradientBoostingClassifier())\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Определение сетки гиперпараметров\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" 'classifier__n_estimators': [100, 200, 300],\n",
|
|||
|
" #Скорость обучения\n",
|
|||
|
" 'classifier__learning_rate': [0.1, 0.2],\n",
|
|||
|
" #Максимальная глубина дерева\n",
|
|||
|
" 'classifier__max_depth': [3, 5, 7],\n",
|
|||
|
" 'classifier__subsample': [0.1, 0.5, 1.0],\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Создание объекта GridSearchCV\n",
|
|||
|
"grid_search = GridSearchCV(pipeline_grad, param_grid, cv=2, scoring='roc_auc', n_jobs=-1)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели с перебором гиперпараметров\n",
|
|||
|
"grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Лучшие гиперпараметры: \", grid_search.best_params_)\n",
|
|||
|
"\n",
|
|||
|
"# Лучшая модель\n",
|
|||
|
"best_model = grid_search.best_estimator_\n",
|
|||
|
"\n",
|
|||
|
"# Использование и оценка лучшей модели\n",
|
|||
|
"y_pred_proba = best_model.predict_proba(X_test)[:, 1]\n",
|
|||
|
"print(f'ROC у метода градиентного спуска = {roc_auc_score(y_test, y_pred_proba)}')\n",
|
|||
|
"\n",
|
|||
|
"y_pred = best_model.predict(X_test)\n",
|
|||
|
"print(f'Точность = {accuracy_score(y_test, y_pred)}')\n",
|
|||
|
"\n",
|
|||
|
"fpr, tpr, _ = metrics.roc_curve(y_test, y_pred_proba)\n",
|
|||
|
"\n",
|
|||
|
"# построение ROC кривой\n",
|
|||
|
"plt.plot(fpr, tpr)\n",
|
|||
|
"plt.ylabel('True Positive Rate')\n",
|
|||
|
"plt.xlabel('False Positive Rate')\n",
|
|||
|
"plt.show()\n",
|
|||
|
"\n",
|
|||
|
"# Построение матрицы ошибок\n",
|
|||
|
"conf_matrix = confusion_matrix(y_test, y_pred)\n",
|
|||
|
"\n",
|
|||
|
"# Визуализация матрицы ошибок\n",
|
|||
|
"plt.figure(figsize=(8, 6))\n",
|
|||
|
"sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', \n",
|
|||
|
" xticklabels=['Предсказанный \"безопасный\"', 'Предсказанный \"опасный\"'], \n",
|
|||
|
" yticklabels=['Действительно \"безопасный\"', 'Действительно \"опасный\"'])\n",
|
|||
|
"plt.title('Confusion Matrix')\n",
|
|||
|
"plt.ylabel('Actual')\n",
|
|||
|
"plt.xlabel('Predicted')\n",
|
|||
|
"plt.show()\n",
|
|||
|
"\n",
|
|||
|
"# Оценка дисперсии и смещения\n",
|
|||
|
"cv_results = grid_search.cv_results_\n",
|
|||
|
"mean_test_score = cv_results['mean_test_score']\n",
|
|||
|
"std_test_score = cv_results['std_test_score']\n",
|
|||
|
"\n",
|
|||
|
"print(f\"Смещение: {mean_test_score.mean()}\")\n",
|
|||
|
"print(f\"Дисперсия: {std_test_score.mean()}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"**Вывод**:\n",
|
|||
|
"\n",
|
|||
|
"Все модели классификации показали хорошие результаты, но лучший показатель точности у случайного леса. При этом все рассмотренные модели немного не дотянули до показателя точности в 90%. Дополнительая настройка гиперпараметров могла бы приблизить значение оценки ещё ближе к 90% "
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "aimenv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.12.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|