484 lines
443 KiB
Plaintext
484 lines
443 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"1. Бизнес-цели для решения задач\n",
|
|||
|
"регрессии - предсказание цены страховки (целевой столбец - charges)\n",
|
|||
|
"классификации - отнести стоимость страховки к одной из категорий (низкая, средняя, высокая) по 33-му и 66-му перцентилям цены\n",
|
|||
|
"\n",
|
|||
|
"2. Достижимый уровень качества\n",
|
|||
|
"данный датасет не требует от человека применения сложных средств и методов для разметки данных. Также я буду исходить из идеи, что все необходимые сведения предоставлены. Исходя из этого, можно сказать, что модель может достичь \"хорошего\" уровня (цитата из презентации)\n",
|
|||
|
"\n",
|
|||
|
"3. Выбор ориентира\n",
|
|||
|
"регрессия - средняя стоимость страховки\n",
|
|||
|
"классификация - наиболее часто встречающаяся категория"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Подготовка данных (взято из 3 ЛР)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"2772\n",
|
|||
|
"было 2772\n",
|
|||
|
"age 39.10966810966811 14.081459420836477\n",
|
|||
|
"bmi 30.70134920634921 6.1294486949652205\n",
|
|||
|
"children 1.1026753434562546 1.2157555494600176\n",
|
|||
|
"charges 13325.498588795157 12200.175109274192\n",
|
|||
|
"стало 2710\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"df = pd.read_csv(\"../dataset.csv\")\n",
|
|||
|
"print(df.shape[0])\n",
|
|||
|
"\n",
|
|||
|
"print(\"было \", df.shape[0])\n",
|
|||
|
"for column in df.select_dtypes(include=['int', 'float']).columns:\n",
|
|||
|
" mean = df[column].mean()\n",
|
|||
|
" std_dev = df[column].std()\n",
|
|||
|
" print(column, mean, std_dev)\n",
|
|||
|
" \n",
|
|||
|
" lower_bound = mean - 3 * std_dev\n",
|
|||
|
" upper_bound = mean + 3 * std_dev\n",
|
|||
|
" \n",
|
|||
|
" df = df[(df[column] <= upper_bound) & (df[column] >= lower_bound)]\n",
|
|||
|
" \n",
|
|||
|
"print(\"стало \", df.shape[0])\n",
|
|||
|
"df = pd.get_dummies(df, columns=['smoker', 'sex', 'region'])\n",
|
|||
|
"df = df.reset_index(drop=True)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"5-6. Выбор моделей и построение конвейера"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Best parameters for Linear Regression: {}\n",
|
|||
|
"Best parameters for Random Forest Regressor: {'model__max_depth': None, 'model__n_estimators': 200}\n",
|
|||
|
"Best parameters for Gradient Boosting Regressor: {'model__learning_rate': 0.1, 'model__max_depth': 7, 'model__n_estimators': 200}\n",
|
|||
|
"Model: Linear Regression\n",
|
|||
|
"Model: Random Forest Regressor\n",
|
|||
|
"Model: Gradient Boosting Regressor\n",
|
|||
|
"Results for Linear Regression:\n",
|
|||
|
"MSE_train: 34444659.26088208\n",
|
|||
|
"MSE_test: 34814365.10649261\n",
|
|||
|
"R2_train: 0.7519227670167513\n",
|
|||
|
"R2_test: 0.7526515685852002\n",
|
|||
|
"MAE_train: 4063.4298138535955\n",
|
|||
|
"MAE_test: 4177.000255554095\n",
|
|||
|
"Results for Random Forest Regressor:\n",
|
|||
|
"MSE_train: 1068690.9196129853\n",
|
|||
|
"MSE_test: 6893069.254237878\n",
|
|||
|
"R2_train: 0.9923030771114929\n",
|
|||
|
"R2_test: 0.9510262541782973\n",
|
|||
|
"MAE_train: 458.22401190138396\n",
|
|||
|
"MAE_test: 1239.9530445735054\n",
|
|||
|
"Results for Gradient Boosting Regressor:\n",
|
|||
|
"MSE_train: 314879.3352948071\n",
|
|||
|
"MSE_test: 5789682.017738877\n",
|
|||
|
"R2_train: 0.9977321768918687\n",
|
|||
|
"R2_test: 0.9588655785881008\n",
|
|||
|
"MAE_train: 241.34572975609007\n",
|
|||
|
"MAE_test: 870.203742123879\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.discriminant_analysis import StandardScaler\n",
|
|||
|
"from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor\n",
|
|||
|
"from sklearn.linear_model import LinearRegression\n",
|
|||
|
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
|
|||
|
"from sklearn.model_selection import GridSearchCV, train_test_split\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"\n",
|
|||
|
"X_reg = df.drop(columns=['charges'])\n",
|
|||
|
"y_reg = df['charges']\n",
|
|||
|
"\n",
|
|||
|
"X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(X_reg, y_reg, test_size=0.2, random_state=13)\n",
|
|||
|
"\n",
|
|||
|
"models_reg = {\n",
|
|||
|
" 'Linear Regression': LinearRegression(),\n",
|
|||
|
" 'Random Forest Regressor': RandomForestRegressor(random_state=42),\n",
|
|||
|
" 'Gradient Boosting Regressor': GradientBoostingRegressor(random_state=42)\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"pipelines_reg = {}\n",
|
|||
|
"for name, model in models_reg.items():\n",
|
|||
|
" pipelines_reg[name] = Pipeline([\n",
|
|||
|
" ('scaler', StandardScaler()),\n",
|
|||
|
" ('model', model)\n",
|
|||
|
" ])\n",
|
|||
|
"\n",
|
|||
|
"param_grids_reg = {\n",
|
|||
|
" 'Linear Regression': {},\n",
|
|||
|
" 'Random Forest Regressor': {\n",
|
|||
|
" 'model__n_estimators': [100, 200, 300],\n",
|
|||
|
" 'model__max_depth': [None, 10, 20, 30]\n",
|
|||
|
" },\n",
|
|||
|
" 'Gradient Boosting Regressor': {\n",
|
|||
|
" 'model__n_estimators': [100, 200, 300],\n",
|
|||
|
" 'model__learning_rate': [0.01, 0.1, 0.2],\n",
|
|||
|
" 'model__max_depth': [3, 5, 7]\n",
|
|||
|
" }\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"best_models_reg = {}\n",
|
|||
|
"for name, pipeline in pipelines_reg.items():\n",
|
|||
|
" grid_search = GridSearchCV(pipeline, param_grids_reg[name], cv=5, scoring='neg_mean_squared_error')\n",
|
|||
|
" grid_search.fit(X_train_reg, y_train_reg)\n",
|
|||
|
" best_models_reg[name] = grid_search.best_estimator_\n",
|
|||
|
" print(f'Best parameters for {name}: {grid_search.best_params_}')\n",
|
|||
|
"\n",
|
|||
|
"results_reg = {}\n",
|
|||
|
"for model_name in best_models_reg.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" model_pipeline = best_models_reg[model_name]\n",
|
|||
|
"\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train_reg)\n",
|
|||
|
" y_test_predict = model_pipeline.predict(X_test_reg)\n",
|
|||
|
"\n",
|
|||
|
" results_reg[model_name] = {\n",
|
|||
|
" \"pipeline\": model_pipeline,\n",
|
|||
|
" \"preds_train\": y_train_predict,\n",
|
|||
|
" \"preds_test\": y_test_predict,\n",
|
|||
|
" \"MSE_train\": mean_squared_error(y_train_reg, y_train_predict),\n",
|
|||
|
" \"MSE_test\": mean_squared_error(y_test_reg, y_test_predict),\n",
|
|||
|
" \"R2_train\": r2_score(y_train_reg, y_train_predict),\n",
|
|||
|
" \"R2_test\": r2_score(y_test_reg, y_test_predict),\n",
|
|||
|
" \"MAE_train\": mean_absolute_error(y_train_reg, y_train_predict),\n",
|
|||
|
" \"MAE_test\": mean_absolute_error(y_test_reg, y_test_predict)\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
"for model_name, results in results_reg.items():\n",
|
|||
|
" print(f\"Results for {model_name}:\")\n",
|
|||
|
" print(f\"MSE_train: {results['MSE_train']}\")\n",
|
|||
|
" print(f\"MSE_test: {results['MSE_test']}\")\n",
|
|||
|
" print(f\"R2_train: {results['R2_train']}\")\n",
|
|||
|
" print(f\"R2_test: {results['R2_test']}\")\n",
|
|||
|
" print(f\"MAE_train: {results['MAE_train']}\")\n",
|
|||
|
" print(f\"MAE_test: {results['MAE_test']}\")\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Best parameters for Logistic Regression: {'model__C': 10, 'model__solver': 'liblinear'}\n",
|
|||
|
"Best parameters for Random Forest Classifier: {'model__max_depth': None, 'model__n_estimators': 100}\n",
|
|||
|
"Best parameters for Gradient Boosting Classifier: {'model__learning_rate': 0.1, 'model__max_depth': 5, 'model__n_estimators': 300}\n",
|
|||
|
"Model: Logistic Regression\n",
|
|||
|
"Model: Random Forest Classifier\n",
|
|||
|
"Model: Gradient Boosting Classifier\n",
|
|||
|
"Results for Logistic Regression:\n",
|
|||
|
"Precision_train: 0.8744992796427931\n",
|
|||
|
"Precision_test: 0.8758164684756503\n",
|
|||
|
"Recall_train: 0.841789667896679\n",
|
|||
|
"Recall_test: 0.8431734317343174\n",
|
|||
|
"Accuracy_train: 0.841789667896679\n",
|
|||
|
"Accuracy_test: 0.8431734317343174\n",
|
|||
|
"ROC_AUC_test: 0.9355010585912291\n",
|
|||
|
"F1_train: 0.834330465466937\n",
|
|||
|
"F1_test: 0.8350142662043457\n",
|
|||
|
"MCC_test: 0.7851901595251827\n",
|
|||
|
"Cohen_kappa_test: 0.7657628342341151\n",
|
|||
|
"Confusion_matrix:\n",
|
|||
|
"[[173 3 0]\n",
|
|||
|
" [ 5 172 0]\n",
|
|||
|
" [ 17 60 112]]\n",
|
|||
|
"\n",
|
|||
|
"Results for Random Forest Classifier:\n",
|
|||
|
"Precision_train: 0.999080053300533\n",
|
|||
|
"Precision_test: 0.9746115352901856\n",
|
|||
|
"Recall_train: 0.9990774907749077\n",
|
|||
|
"Recall_test: 0.974169741697417\n",
|
|||
|
"Accuracy_train: 0.9990774907749077\n",
|
|||
|
"Accuracy_test: 0.974169741697417\n",
|
|||
|
"ROC_AUC_test: 0.9937592800800455\n",
|
|||
|
"F1_train: 0.9990775021680652\n",
|
|||
|
"F1_test: 0.974013999727198\n",
|
|||
|
"MCC_test: 0.9616316062344435\n",
|
|||
|
"Cohen_kappa_test: 0.96125843706283\n",
|
|||
|
"Confusion_matrix:\n",
|
|||
|
"[[176 0 0]\n",
|
|||
|
" [ 0 175 2]\n",
|
|||
|
" [ 4 8 177]]\n",
|
|||
|
"\n",
|
|||
|
"Results for Gradient Boosting Classifier:\n",
|
|||
|
"Precision_train: 0.999080053300533\n",
|
|||
|
"Precision_test: 0.9706437064370643\n",
|
|||
|
"Recall_train: 0.9990774907749077\n",
|
|||
|
"Recall_test: 0.9704797047970479\n",
|
|||
|
"Accuracy_train: 0.9990774907749077\n",
|
|||
|
"Accuracy_test: 0.9704797047970479\n",
|
|||
|
"ROC_AUC_test: 0.9927092854328317\n",
|
|||
|
"F1_train: 0.9990775021680652\n",
|
|||
|
"F1_test: 0.9703257159470196\n",
|
|||
|
"MCC_test: 0.9559528402941749\n",
|
|||
|
"Cohen_kappa_test: 0.9557185020271858\n",
|
|||
|
"Confusion_matrix:\n",
|
|||
|
"[[176 0 0]\n",
|
|||
|
" [ 0 173 4]\n",
|
|||
|
" [ 4 8 177]]\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier\n",
|
|||
|
"from sklearn.linear_model import LogisticRegression\n",
|
|||
|
"from sklearn.metrics import cohen_kappa_score, confusion_matrix, f1_score, matthews_corrcoef, precision_score, recall_score, roc_auc_score, accuracy_score\n",
|
|||
|
"from sklearn.model_selection import GridSearchCV, train_test_split\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"# Предположим, что у вас уже есть DataFrame df с колонкой 'charges'\n",
|
|||
|
"# df = pd.read_csv('your_data.csv')\n",
|
|||
|
"\n",
|
|||
|
"# Вычисляем 33-й и 66-й перцентили\n",
|
|||
|
"bins = [float('-inf'), df['charges'].quantile(0.33), df['charges'].quantile(0.66), float('inf')]\n",
|
|||
|
"labels = ['Low', 'Medium', 'High']\n",
|
|||
|
"\n",
|
|||
|
"# Создаем новый столбец 'ChangesGroup' на основе заданных границ\n",
|
|||
|
"df['ChangesGroup'] = pd.cut(df['charges'], bins=bins, labels=labels)\n",
|
|||
|
"\n",
|
|||
|
"# Преобразуем категориальные данные в числовые\n",
|
|||
|
"df['ChangesGroup'] = df['ChangesGroup'].cat.codes\n",
|
|||
|
"\n",
|
|||
|
"# Выбор признаков и целевой переменной для классификации\n",
|
|||
|
"X_class = df.drop(['charges', 'ChangesGroup'], axis=1)\n",
|
|||
|
"y_class = df['ChangesGroup']\n",
|
|||
|
"\n",
|
|||
|
"# Разделение данных\n",
|
|||
|
"X_train_class, X_test_class, y_train_class, y_test_class = train_test_split(X_class, y_class, test_size=0.2, random_state=13)\n",
|
|||
|
"\n",
|
|||
|
"# Выбор моделей для классификации\n",
|
|||
|
"models_class = {\n",
|
|||
|
" 'Logistic Regression': LogisticRegression(random_state=13, max_iter=5000, solver='liblinear'),\n",
|
|||
|
" 'Random Forest Classifier': RandomForestClassifier(random_state=13),\n",
|
|||
|
" 'Gradient Boosting Classifier': GradientBoostingClassifier(random_state=13)\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Создание конвейера для классификации\n",
|
|||
|
"pipelines_class = {}\n",
|
|||
|
"for name, model in models_class.items():\n",
|
|||
|
" pipelines_class[name] = Pipeline([\n",
|
|||
|
" ('scaler', StandardScaler()),\n",
|
|||
|
" ('model', model)\n",
|
|||
|
" ])\n",
|
|||
|
"\n",
|
|||
|
"# Определение сетки гиперпараметров для классификации\n",
|
|||
|
"param_grids_class = {\n",
|
|||
|
" 'Logistic Regression': {\n",
|
|||
|
" 'model__C': [0.1, 1, 10],\n",
|
|||
|
" 'model__solver': ['lbfgs', 'liblinear']\n",
|
|||
|
" },\n",
|
|||
|
" 'Random Forest Classifier': {\n",
|
|||
|
" 'model__n_estimators': [100, 200, 300],\n",
|
|||
|
" 'model__max_depth': [None, 10, 20, 30]\n",
|
|||
|
" },\n",
|
|||
|
" 'Gradient Boosting Classifier': {\n",
|
|||
|
" 'model__n_estimators': [100, 200, 300],\n",
|
|||
|
" 'model__learning_rate': [0.01, 0.1, 0.2],\n",
|
|||
|
" 'model__max_depth': [3, 5, 7]\n",
|
|||
|
" }\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Настройка гиперпараметров для классификации\n",
|
|||
|
"best_models_class = {}\n",
|
|||
|
"for name, pipeline in pipelines_class.items():\n",
|
|||
|
" grid_search = GridSearchCV(pipeline, param_grids_class[name], cv=5, scoring='accuracy', n_jobs=-1)\n",
|
|||
|
" grid_search.fit(X_train_class, y_train_class)\n",
|
|||
|
" best_models_class[name] = {\"model\": grid_search.best_estimator_}\n",
|
|||
|
" print(f'Best parameters for {name}: {grid_search.best_params_}')\n",
|
|||
|
"\n",
|
|||
|
"# Обучение моделей и оценка качества\n",
|
|||
|
"for model_name in best_models_class.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" model = best_models_class[model_name][\"model\"]\n",
|
|||
|
"\n",
|
|||
|
" model_pipeline = Pipeline([(\"scaler\", StandardScaler()), (\"model\", model)])\n",
|
|||
|
" model_pipeline = model_pipeline.fit(X_train_class, y_train_class)\n",
|
|||
|
"\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train_class)\n",
|
|||
|
" y_test_probs = model_pipeline.predict_proba(X_test_class)\n",
|
|||
|
" y_test_predict = model_pipeline.predict(X_test_class)\n",
|
|||
|
"\n",
|
|||
|
" best_models_class[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" best_models_class[model_name][\"probs\"] = y_test_probs\n",
|
|||
|
" best_models_class[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" best_models_class[model_name][\"Precision_train\"] = precision_score(y_train_class, y_train_predict, average='weighted')\n",
|
|||
|
" best_models_class[model_name][\"Precision_test\"] = precision_score(y_test_class, y_test_predict, average='weighted')\n",
|
|||
|
" best_models_class[model_name][\"Recall_train\"] = recall_score(y_train_class, y_train_predict, average='weighted')\n",
|
|||
|
" best_models_class[model_name][\"Recall_test\"] = recall_score(y_test_class, y_test_predict, average='weighted')\n",
|
|||
|
" best_models_class[model_name][\"Accuracy_train\"] = accuracy_score(y_train_class, y_train_predict)\n",
|
|||
|
" best_models_class[model_name][\"Accuracy_test\"] = accuracy_score(y_test_class, y_test_predict)\n",
|
|||
|
" best_models_class[model_name][\"ROC_AUC_test\"] = roc_auc_score(y_test_class, y_test_probs, multi_class='ovr')\n",
|
|||
|
" best_models_class[model_name][\"F1_train\"] = f1_score(y_train_class, y_train_predict, average='weighted')\n",
|
|||
|
" best_models_class[model_name][\"F1_test\"] = f1_score(y_test_class, y_test_predict, average='weighted')\n",
|
|||
|
" best_models_class[model_name][\"MCC_test\"] = matthews_corrcoef(y_test_class, y_test_predict)\n",
|
|||
|
" best_models_class[model_name][\"Cohen_kappa_test\"] = cohen_kappa_score(y_test_class, y_test_predict)\n",
|
|||
|
" best_models_class[model_name][\"Confusion_matrix\"] = confusion_matrix(y_test_class, y_test_predict)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод результатов\n",
|
|||
|
"for model_name, results in best_models_class.items():\n",
|
|||
|
" print(f\"Results for {model_name}:\")\n",
|
|||
|
" print(f\"Precision_train: {results['Precision_train']}\")\n",
|
|||
|
" print(f\"Precision_test: {results['Precision_test']}\")\n",
|
|||
|
" print(f\"Recall_train: {results['Recall_train']}\")\n",
|
|||
|
" print(f\"Recall_test: {results['Recall_test']}\")\n",
|
|||
|
" print(f\"Accuracy_train: {results['Accuracy_train']}\")\n",
|
|||
|
" print(f\"Accuracy_test: {results['Accuracy_test']}\")\n",
|
|||
|
" print(f\"ROC_AUC_test: {results['ROC_AUC_test']}\")\n",
|
|||
|
" print(f\"F1_train: {results['F1_train']}\")\n",
|
|||
|
" print(f\"F1_test: {results['F1_test']}\")\n",
|
|||
|
" print(f\"MCC_test: {results['MCC_test']}\")\n",
|
|||
|
" print(f\"Cohen_kappa_test: {results['Cohen_kappa_test']}\")\n",
|
|||
|
" print(f\"Confusion_matrix:\\n{results['Confusion_matrix']}\")\n",
|
|||
|
" print()\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAc8AAAQ9CAYAAAA/GsaeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAD9Y0lEQVR4nOzdd1gU1xoG8HdpC9J7UVDsYC/RoCigRCyxBHvUgDUmYu83YotKNBqNlcQY0URiEluUJMaKXWPDFoKiqFgADdJD2537B2F0BXQHF3Hh/T3PPDd75szMNyuXj3PmzDkyQRAEEBERkdp0yjsAIiIibcPkSUREJBGTJxERkURMnkRERBIxeRIREUnE5ElERCQRkycREZFETJ5EREQSMXkSERFJxORJFYK3tze8vb01dr4aNWogMDBQY+cjQCaTYe7cueUdBpFGMHmSRoWFhUEmk+HcuXPlHcpLnTx5EnPnzkVKSkqZXqdGjRqQyWTiZmxsjFatWmHz5s1lel0iKjt65R0AkSbs27dP8jEnT57EvHnzEBgYCAsLC5V9MTEx0NHR3N+WTZs2xeTJkwEADx8+xDfffIOAgADk5ORg5MiRGrvOm+zff/+Fnh5/5VDFwJ9kqhAMDAw0ej65XK7R81WtWhWDBw8WPwcGBqJmzZpYvnz5a0+emZmZMDY2fq3XBABDQ8PXfk2issJuWyoXFy9eRJcuXWBmZgYTExN07NgRp0+fLlLv8uXL8PLygpGREapVq4YFCxZg48aNkMlkuH37tlivuGeeq1atQoMGDVClShVYWlqiZcuWCA8PBwDMnTsXU6dOBQC4urqKXaqF5yzumWdKSgomTpyIGjVqQC6Xo1q1avjggw/w+PFjyfdva2uL+vXr4+bNmyrlSqUSK1asQIMGDWBoaAh7e3t8+OGHePLkSZF6c+fOhZOTE6pUqQIfHx/89ddfReIu7EY/cuQIPv74Y9jZ2aFatWri/t9//x3t2rWDsbExTE1N0a1bN1y7dk3lWgkJCRg6dCiqVasGuVwOR0dH9OzZU+X7P3fuHPz8/GBjYwMjIyO4urpi2LBhKucp7pmnOj8Hhfdw4sQJTJo0Cba2tjA2NsZ7772HR48eqfuVE2kUW5702l27dg3t2rWDmZkZpk2bBn19fXz11Vfw9vbGkSNH0Lp1awDA/fv34ePjA5lMhpkzZ8LY2BjffPONWq3C9evXY9y4cejTpw/Gjx+P7OxsXL58GWfOnMH7778Pf39/XL9+HT/88AOWL18OGxsbAAVJrTgZGRlo164doqOjMWzYMDRv3hyPHz/G7t27ce/ePfF4deXn5+PevXuwtLRUKf/www8RFhaGoUOHYty4cYiLi8Pq1atx8eJFnDhxAvr6+gCAmTNnYsmSJejevTv8/Pxw6dIl+Pn5ITs7u9jrffzxx7C1tcXs2bORmZkJAPjuu+8QEBAAPz8/LF68GFlZWVi3bh08PT1x8eJF1KhRAwDQu3dvXLt2DWPHjkWNGjWQlJSE/fv34+7du+LnTp06wdbWFjNmzICFhQVu376NHTt2vPA7UPfnoNDYsWNhaWmJOXPm4Pbt21ixYgWCgoLw448/SvruiTRCINKgjRs3CgCEs2fPllinV69egoGBgXDz5k2x7MGDB4KpqanQvn17sWzs2LGCTCYTLl68KJb9888/gpWVlQBAiIuLE8u9vLwELy8v8XPPnj2FBg0avDDWzz//vMh5ClWvXl0ICAgQP8+ePVsAIOzYsaNIXaVS+cLrVK9eXejUqZPw6NEj4dGjR8KVK1eEIUOGCACEMWPGiPWOHTsmABC2bNmicvzevXtVyhMSEgQ9PT2hV69eKvXmzp0rAFCJu/Dfw9PTU8jPzxfL09PTBQsLC2HkyJEq50hISBDMzc3F8idPnggAhM8//7zE+9u5c+dL/80FQRAACHPmzBE/q/tzUHgPvr6+Kt/1xIkTBV1dXSElJeWF1yUqC+y2pddKoV
Bg37596NWrF2rWrCmWOzo64v3338fx48eRlpYGANi7dy88PDzQtGlTsZ6VlRUGDRr00utYWFjg3r17OHv2rEbi3r59O5o0aYL33nuvyD6ZTPbS4/ft2wdbW1vY2tqiUaNG+O677zB06FB8/vnnYp2ff/4Z5ubmeOedd/D48WNxa9GiBUxMTHD48GEAwMGDB5Gfn4+PP/5Y5Rpjx44t8fojR46Erq6u+Hn//v1ISUnBwIEDVa6lq6uL1q1bi9cyMjKCgYEBIiMji3QdFyocbBUREYG8vLyXfheAtJ+DQqNGjVL5rtu1aweFQoE7d+6odU0iTWLypNfq0aNHyMrKQr169Yrsc3Nzg1KpRHx8PADgzp07qF27dpF6xZU9b/r06TAxMUGrVq1Qp04djBkzBidOnCh13Ddv3kTDhg1LfXzr1q2xf/9+7N27F0uXLoWFhQWePHmiMtDpxo0bSE1NhZ2dnZhoC7eMjAwkJSUBgJgsnv8erKysinQDF3J1dVX5fOPGDQBAhw4dilxr37594rXkcjkWL16M33//Hfb29mjfvj2WLFmChIQE8VxeXl7o3bs35s2bBxsbG/Ts2RMbN25ETk5Oid+HlJ+DQi4uLiqfC++1pKROVJb4zJMqJDc3N8TExCAiIgJ79+7F9u3bsXbtWsyePRvz5s177fHY2NjA19cXAODn54f69evj3XffxZdffolJkyYBKBgEZGdnhy1bthR7jpKex6rDyMhI5bNSqQRQ8NzTwcGhSP1nXymZMGECunfvjl27duGPP/5AcHAwQkJCcOjQITRr1gwymQzbtm3D6dOnsWfPHvzxxx8YNmwYli1bhtOnT8PExKTUcT/r2ZbzswRB0Mj5iaRg8qTXytbWFlWqVEFMTEyRfX///Td0dHTg7OwMAKhevTpiY2OL1CuurDjGxsbo378/+vfvj9zcXPj7+2PhwoWYOXMmDA0N1epuLVSrVi1cvXpV7fov061bN3h5eWHRokX48MMPYWxsjFq1auHAgQNo27ZtkWT3rOrVqwMo+B6ebVH+888/arfCatWqBQCws7MTk/rL6k+ePBmTJ0/GjRs30LRpUyxbtgzff/+9WOftt9/G22+/jYULFyI8PByDBg3C1q1bMWLEiCLnk/JzQPQmYrctvVa6urro1KkTfvnlF5VXHRITExEeHg5PT0+YmZkBKGihnTp1ClFRUWK95OTkEltmz/rnn39UPhsYGMDd3R2CIIjP5QrfdVRnhqHevXvj0qVL2LlzZ5F9pW35TJ8+Hf/88w/Wr18PAOjXrx8UCgU+/fTTInXz8/PFODt27Ag9PT2sW7dOpc7q1avVvrafnx/MzMywaNGiYp9TFr4CkpWVVWQEb61atWBqaip2yz558qTId1D4nLqkrlspPwdEbyK2PKlMfPvtt9i7d2+R8vHjx2PBggXYv38/PD098fHHH0NPTw9fffUVcnJysGTJErHutGnT8P333+Odd97B2LFjxVdVXFxckJyc/MKWY6dOneDg4IC2bdvC3t4e0dHRWL16Nbp16wZTU1MAQIsWLQAAn3zyCQYMGAB9fX1079692AkEpk6dim3btqFv374YNmwYWrRogeTkZOzevRuhoaFo0qSJ5O+oS5cuaNiwIb744guMGTMGXl5e+PDDDxESEoKoqCh06tQJ+vr6uHHjBn7++Wd8+eWX6NOnD+zt7TF+/HgsW7YMPXr0QOfOnXHp0iX8/vvvsLGxUatFbWZmhnXr1mHIkCFo3rw5BgwYAFtbW9y9exe//vor2rZti9WrV+P69evo2LEj+vXrB3d3d+jp6WHnzp1ITEzEgAEDAACbNm3C2rVr8d5776FWrVpIT0/H+vXrYWZmhq5du5YYg7o/B0RvpPId7EsVTeFrBSVt8fHxgiAIwoULFwQ/Pz/BxMREqFKliuDj4yOcPHmyyPkuXrwotGvXTpDL5UK1atWEkJAQYeXKlQIAISEhQaz3/KsqX331ldC+fXvB2tpakM
vlQq1atYSpU6cKqampKuf/9NNPhapVqwo6Ojoqr608/6qKIBS8JhMUFCRUrVpVMDAwEKpVqyYEBAQIjx8/fuF3Ur1
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x1000 with 6 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from matplotlib import pyplot as plt\n",
|
|||
|
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
|||
|
"\n",
|
|||
|
"# Количество моделей\n",
|
|||
|
"num_models = len(best_models_class)\n",
|
|||
|
"\n",
|
|||
|
"# Создание фигуры и осей для отображения матриц ошибок\n",
|
|||
|
"fig, ax = plt.subplots(num_models, 1, figsize=(12, 10), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"# Перебор моделей и отображение матриц ошибок\n",
|
|||
|
"for index, key in enumerate(best_models_class.keys()):\n",
|
|||
|
" c_matrix = best_models_class[key][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"Low\", \"Medium\", \"High\"]\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(key)\n",
|
|||
|
"\n",
|
|||
|
"# Настройка расположения подзаголовков\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
|||
|
"plt.show()\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABOoAAAQ9CAYAAADnOaJIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXgUVfbw8W919Zru7GRhCSQEZF8EBBEVHJDgoA6KCi4jIqgooIDjNqO4j79x13FhHEfxnYFRXEdFEARxAVRkUUBAgiCRkH3v9F71/tGkSUgCCSR0Es7nIc+Trr6pvtXdVN8+de65iq7rOkIIIYQQQgghhBBCiLAyhLsDQgghhBBCCCGEEEIICdQJIYQQQgghhBBCCNEiSKBOCCGEEEIIIYQQQogWQAJ1QgghhBBCCCGEEEK0ABKoE0IIIYQQQgghhBCiBZBAnRBCCCGEEEIIIYQQLYAE6oQQQgghhBBCCCGEaAEkUCeEEEIIIYQQQgghRAsggTohhBBCCCGEEEIIIVoACdQJ0UT27duHoigsXLgw3F0RJ+C6664jNTU13N1oMxRF4YEHHgh3N8Ju1KhRjBo1KnS7JZ4vjuyjEEKcbC3x3CgaT8ZSQQ888ACKojSo7ckYL8nnfMPJuE2EmwTqhGiAhQsXoigK33//fbi70myqBhNVPyaTidTUVG699VZKSkrC3T0BvPTSSyiKwrBhw457H9nZ2TzwwANs2bKl6TrWwq1Zs6bWe7tr165ce+21/PLLL+HuXqOsW7eOBx54QP5PCiFaHRlLlYS7e6esqvde1Y/RaKRjx45cd911HDhwINzdE0eQcZsQYAx3B4RoK7p06YLL5cJkMoW7Kyfk5ZdfxuFw4HQ6WbVqFX//+9/ZtGkTX3/9dbi7dlL885//RNO0cHejTosWLSI1NZXvvvuOzMxMunXr1uh9ZGdn8+CDD5KamsrAgQObvpMt2K233soZZ5yBz+dj06ZNvPLKKyxdupStW7fSoUOHk9qX4z1frFu3jgcffJDrrruOmJiY5umcEEKEiYyl2oaWOpZ66KGHSEtLw+12880337Bw4UK+/vprtm3bhtVqbfLHu/fee7n77rubfL+nChm3iVOZZNQJ0UQURcFqtaKqari7Uq/Kyspjtrnsssu45ppruOmmm1iyZAmTJk1i7dq1fPfddyehh4dpmobb7T6pjwlgMpmwWCwn/XGPZe/evaxbt46nn36ahIQEFi1aFO4utTrnnHMO11xzDVOnTuXvf/87Tz75JEVFRbzxxhv1/o3T6WyWvrSG84UQQpxsreHcKGOpY2upY6kLLriAa665hunTp/Pqq6/ypz/9iT179vDhhx82y+MZjcZmCQCeKmTcJk5lEqgToonUVbvguuuuw+FwcODAASZMmIDD4SAhIYE//elPBAKBGn+vaRrPPvssffr0wWq1kpSUxE033URxcXGNdv/73/8YP348HTp0wGKxkJ6ezsMPP1xrf6NGjaJv375s3LiRc889l4iICP785z83+rjOOeccAPbs2VNj+7fffsu4ceOIjo4mIiKCkSNHsnbt2lp/v2bNGoYMGYLVaiU9PZ1//OMfddbsUBSFWbNmsWjRIvr06YPFYmH58uUAHDhwgOuvv56kpCQsFgt9+vThtddeq/VYf//73+nTpw8RERHExsYyZMgQFi9eHLq/vLycOXPmkJqaisViITExkfPPP59NmzaF2tRVV8XpdHL77beTkpKCxWKhR48ePPnkk+i6XucxfPDBB/Tt2zfU16rjqG7nzp3s37+/rqe8TosWLSI2Npbx48dz2WWX1RuoKykpYe7cuaFj7NSpE9deey0FBQWsWbOGM844A4CpU6eGphRUvWdTU1O57rrrau3zyBoYXq+X+fPnM3jwYKKjo7Hb7Zxzzjl8/vnnDT6eKrm5uRiNRh588MFa9+3atQtFUXjhhRcA8Pl8PPjgg3Tv3h2r1Up8fDxnn302K1eubP
TjAvzud78DgkFQODxl6aeffuKqq64iNjaWs88+O9T+P//5D4MHD8ZmsxEXF8fkyZPJysqqtd9XXnmF9PR0bDYbQ4cO5auvvqrVpr5aJzt37uSKK64gISEBm81Gjx49+Mtf/hLq3x133AFAWlpa6PXbt29fs/RRCCFONhlLyViqOcdSR6rvddm5cyeXXXYZcXFxWK1WhgwZUiuY15AxSV2vkcfjYe7cuSQkJBAZGcnFF1/Mb7/9Vqtv9dX5q2ufr7/+Or/73e9ITEzEYrHQu3dvXn755QY9B8d6vY8k47aFNbbLuE00F5n6KkQzCwQCZGRkMGzYMJ588kk+++wznnrqKdLT07n55ptD7W666SYWLlzI1KlTufXWW9m7dy8vvPACmzdvZu3ataFU64ULF+JwOJg3bx4Oh4PVq1czf/58ysrKeOKJJ2o8dmFhIRdccAGTJ0/mmmuuISkpqdH9r/owiY2NDW1bvXo1F1xwAYMHD+b+++/HYDCEBglfffUVQ4cOBWDz5s2MGzeO9u3b8+CDDxIIBHjooYdISEio87FWr17NkiVLmDVrFu3atSM1NZXc3FzOPPPM0MAtISGBZcuWMW3aNMrKypgzZw4QnGZx6623ctlll3Hbbbfhdrv58ccf+fbbb7nqqqsAmDFjBu+88w6zZs2id+/eFBYW8vXXX7Njxw4GDRpUZ590Xefiiy/m888/Z9q0aQwcOJBPP/2UO+64gwMHDvDMM8/UaP/111/z3nvvccsttxAZGcnzzz/PxIkT2b9/P/Hx8aF2vXr1YuTIkaxZs6ZBr8OiRYu49NJLMZvNXHnllbz88sts2LAhFHgDqKio4JxzzmHHjh1cf/31DBo0iIKCAj788EN+++03evXqxUMPPcT8+fO58cYbQwPUs846q0F9qFJWVsarr77KlVdeyQ033EB5eTn/+te/yMjI4LvvvmvUlNqkpCRGjhzJkiVLuP/++2vc99Zbb6GqKpdffjkQHPA89thjTJ8+naFDh1JWVsb333/Ppk2bOP/88xt1DHB4YF79dQG4/PLL6d69O3/9619DXyAeffRR7rvvPq644gqmT59Ofn4+f//73zn33HPZvHlzaDrDv/71L2666SbOOuss5syZwy+//MLFF19MXFwcKSkpR+3Pjz/+yDnnnIPJZOLGG28kNTWVPXv28NFHH/Hoo49y6aWX8vPPP/Pf//6XZ555hnbt2gGE/j+djD4KIUQ4yFhKxlJNMZY6Ul2vy/bt2xkxYgQdO3bk7rvvxm63s2TJEiZMmMC7777LJZdcAhz/mGT69On85z//4aqrruKss85i9erVjB8//rj6X+Xll1+mT58+XHzxxRiNRj766CNuueUWNE1j5syZ9f5dQ17vI8m47TAZt4lmpQshjun111/XAX3Dhg31ttm7d68O6K+//npo25QpU3RAf+ihh2q0Pf300/XBgweHbn/11Vc6oC9atKhGu+XLl9faXllZWeuxb7rpJj0iIkJ3u92hbSNHjtQBfcGCBQ06xvvvv18H9F27dun5+fn6vn379Ndee0232Wx6QkKC7nQ6dV3XdU3T9O7du+sZGRm6pmk1+pWWlqaff/75oW0XXXSRHhERoR84cCC0bffu3brRaNSPPP0AusFg0Ldv315j+7Rp0/T27dvrBQUFNbZPnjxZj46ODj0ff/jDH/Q+ffoc9Rijo6P1mTNnHrXNlClT9C5duoRuf/DBBzqgP/LIIzXaXXbZZbqiKHpmZmaNYzCbzTW2/fDDDzqg//3vf691vCNHjjxqX6p8//33OqCvXLlS1/Xga9CpUyf9tttuq9Fu/vz5OqC/9957tfZR9Vpt2LCh1vu0SpcuXfQpU6bU2j5y5MgaffX7/brH46nRpri4WE9KStKvv/76GtsB/f777z/q8f3jH//QAX3r1q01tvfu3Vv/3e9+F7o9YMAAffz48UfdV10+//xzHdBfe+01PT8/X8/OztaXLl2qp6am6oqihP5fV/
0fuPLKK2v8/b59+3RVVfVHH320xvatW7fqRqMxtN3r9eqJiYn6wIEDazw/r7zySq3Xu67zxbnnnqtHRkbqv/76a43
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1500x1000 with 6 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from matplotlib import pyplot as plt\n",
|
|||
|
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
|
|||
|
"\n",
|
|||
|
"# Создание фигуры и осей для отображения графиков\n",
|
|||
|
"fig, ax = plt.subplots(len(best_models_reg), 2, figsize=(15, 10), sharex=False, sharey=False)\n",
|
|||
|
"ax = ax.flatten()\n",
|
|||
|
"\n",
|
|||
|
"# Перебор моделей и отображение графиков\n",
|
|||
|
"for index, (name, model_results) in enumerate(results_reg.items()):\n",
|
|||
|
" y_test_reg = y_test_reg\n",
|
|||
|
" y_pred_reg = model_results[\"preds_test\"]\n",
|
|||
|
"\n",
|
|||
|
" # График фактических значений против предсказанных значений\n",
|
|||
|
" ax[index * 2].scatter(y_test_reg, y_pred_reg, alpha=0.5)\n",
|
|||
|
" ax[index * 2].plot([min(y_test_reg), max(y_test_reg)], [min(y_test_reg), max(y_test_reg)], color='red', linestyle='--')\n",
|
|||
|
" ax[index * 2].set_xlabel('Actual Values')\n",
|
|||
|
" ax[index * 2].set_ylabel('Predicted Values')\n",
|
|||
|
" ax[index * 2].set_title(f'{name}: Actual vs Predicted')\n",
|
|||
|
"\n",
|
|||
|
" # График остатков\n",
|
|||
|
" residuals = y_test_reg - y_pred_reg\n",
|
|||
|
" ax[index * 2 + 1].scatter(y_pred_reg, residuals, alpha=0.5)\n",
|
|||
|
" ax[index * 2 + 1].axhline(y=0, color='red', linestyle='--')\n",
|
|||
|
" ax[index * 2 + 1].set_xlabel('Predicted Values')\n",
|
|||
|
" ax[index * 2 + 1].set_ylabel('Residuals')\n",
|
|||
|
" ax[index * 2 + 1].set_title(f'{name}: Residuals vs Predicted')\n",
|
|||
|
"\n",
|
|||
|
"# Настройка расположения подзаголовков\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
|||
|
"plt.show()\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "aimenv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.12.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|