476 lines
351 KiB
Plaintext
Raw Permalink Normal View History

2024-12-06 21:40:44 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Бизнес-цели для решения задач\n",
"регрессии - предсказание цены страховки (целевой столбец - charges)\n",
"классификации - отнести цену страховки к одной из категорий (низкая, средняя, высокая); границы категорий - 33-й и 66-й перцентили цены (charges)\n",
"\n",
"2. Достижимый уровень качества\n",
"Данный датасет не требует от человека применения сложных средств и методов для разметки данных. Также я исхожу из предположения, что все необходимые сведения предоставлены. Исходя из этого, можно сказать, что модель может достичь \"хорошего\" уровня (цитата из презентации)\n",
"\n",
"3. Выбор ориентира\n",
"регрессия - средняя стоимость страховки\n",
"классификация - наиболее часто встречающаяся категория"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подготовка данных (взято из 3 ЛР)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2772\n",
"было 2772\n",
"age 39.10966810966811 14.081459420836477\n",
"bmi 30.70134920634921 6.1294486949652205\n",
"children 1.1026753434562546 1.2157555494600176\n",
"charges 13325.498588795157 12200.175109274192\n",
"стало 2710\n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"# Keep the raw frame untouched; the cleaned frame keeps the name `df`, which the later cells use.\n",
"df_raw = pd.read_csv(\"../dataset.csv\")\n",
"print(\"было \", df_raw.shape[0])\n",
"\n",
"# Remove outliers with the 3-sigma rule, one numeric column at a time.\n",
"# Each column's mean/std is computed on the rows that survived the filters\n",
"# of the previous columns (sequential filtering).\n",
"df = df_raw\n",
"for column in df_raw.select_dtypes(include=['int', 'float']).columns:\n",
"    mean = df[column].mean()\n",
"    std_dev = df[column].std()\n",
"    print(column, mean, std_dev)\n",
"\n",
"    lower_bound = mean - 3 * std_dev\n",
"    upper_bound = mean + 3 * std_dev\n",
"    # `between` is inclusive on both ends, same as the explicit <= / >= pair.\n",
"    df = df[df[column].between(lower_bound, upper_bound)]\n",
"\n",
"print(\"стало \", df.shape[0])\n",
"\n",
"# One-hot encode the categorical columns and renumber the rows 0..n-1.\n",
"df = pd.get_dummies(df, columns=['smoker', 'sex', 'region'])\n",
"df = df.reset_index(drop=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"5-6. Выбор моделей и построение конвейера"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters for Linear Regression: {}\n",
"Best parameters for Random Forest Regressor: {'model__max_depth': None, 'model__n_estimators': 200}\n",
"Best parameters for Gradient Boosting Regressor: {'model__learning_rate': 0.1, 'model__max_depth': 7, 'model__n_estimators': 200}\n",
"Model: Linear Regression\n",
"Model: Random Forest Regressor\n",
"Model: Gradient Boosting Regressor\n",
"Results for Linear Regression:\n",
"MSE_train: 34444659.26088208\n",
"MSE_test: 34814365.10649261\n",
"R2_train: 0.7519227670167513\n",
"R2_test: 0.7526515685852002\n",
"MAE_train: 4063.4298138535955\n",
"MAE_test: 4177.000255554095\n",
"Results for Random Forest Regressor:\n",
"MSE_train: 1068690.9196129853\n",
"MSE_test: 6893069.254237878\n",
"R2_train: 0.9923030771114929\n",
"R2_test: 0.9510262541782973\n",
"MAE_train: 458.22401190138396\n",
"MAE_test: 1239.9530445735054\n",
"Results for Gradient Boosting Regressor:\n",
"MSE_train: 314879.3352948071\n",
"MSE_test: 5789682.017738877\n",
"R2_train: 0.9977321768918687\n",
"R2_test: 0.9588655785881008\n",
"MAE_train: 241.34572975609007\n",
"MAE_test: 870.203742123879\n"
]
}
],
"source": [
"from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
"from sklearn.model_selection import GridSearchCV, train_test_split\n",
"from sklearn.pipeline import Pipeline\n",
"# StandardScaler's public home is sklearn.preprocessing (sklearn.discriminant_analysis only imports it incidentally).\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Regression target is 'charges'. 'ChangesGroup' (a class label derived from 'charges'\n",
"# that the classification cell may have added to `df`) is dropped if present;\n",
"# otherwise re-running this cell would leak the target into the features.\n",
"X_reg = df.drop(columns=['charges']).drop(columns=['ChangesGroup'], errors='ignore')\n",
"y_reg = df['charges']\n",
"\n",
"X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(X_reg, y_reg, test_size=0.2, random_state=13)\n",
"\n",
"models_reg = {\n",
"    'Linear Regression': LinearRegression(),\n",
"    'Random Forest Regressor': RandomForestRegressor(random_state=42),\n",
"    'Gradient Boosting Regressor': GradientBoostingRegressor(random_state=42)\n",
"}\n",
"\n",
"# Every model is preceded by feature standardization.\n",
"pipelines_reg = {\n",
"    name: Pipeline([('scaler', StandardScaler()), ('model', model)])\n",
"    for name, model in models_reg.items()\n",
"}\n",
"\n",
"param_grids_reg = {\n",
"    'Linear Regression': {},\n",
"    'Random Forest Regressor': {\n",
"        'model__n_estimators': [100, 200, 300],\n",
"        'model__max_depth': [None, 10, 20, 30]\n",
"    },\n",
"    'Gradient Boosting Regressor': {\n",
"        'model__n_estimators': [100, 200, 300],\n",
"        'model__learning_rate': [0.01, 0.1, 0.2],\n",
"        'model__max_depth': [3, 5, 7]\n",
"    }\n",
"}\n",
"\n",
"# 5-fold CV on the training split; best_estimator_ is refit on the whole training split.\n",
"best_models_reg = {}\n",
"for name, pipeline in pipelines_reg.items():\n",
"    grid_search = GridSearchCV(pipeline, param_grids_reg[name], cv=5, scoring='neg_mean_squared_error', n_jobs=-1)\n",
"    grid_search.fit(X_train_reg, y_train_reg)\n",
"    best_models_reg[name] = grid_search.best_estimator_\n",
"    print(f'Best parameters for {name}: {grid_search.best_params_}')\n",
"\n",
"# Score each tuned pipeline on both splits (train vs. test gap shows overfitting).\n",
"results_reg = {}\n",
"for model_name, model_pipeline in best_models_reg.items():\n",
"    print(f\"Model: {model_name}\")\n",
"\n",
"    y_train_predict = model_pipeline.predict(X_train_reg)\n",
"    y_test_predict = model_pipeline.predict(X_test_reg)\n",
"\n",
"    results_reg[model_name] = {\n",
"        \"pipeline\": model_pipeline,\n",
"        \"preds_train\": y_train_predict,\n",
"        \"preds_test\": y_test_predict,\n",
"        \"MSE_train\": mean_squared_error(y_train_reg, y_train_predict),\n",
"        \"MSE_test\": mean_squared_error(y_test_reg, y_test_predict),\n",
"        \"R2_train\": r2_score(y_train_reg, y_train_predict),\n",
"        \"R2_test\": r2_score(y_test_reg, y_test_predict),\n",
"        \"MAE_train\": mean_absolute_error(y_train_reg, y_train_predict),\n",
"        \"MAE_test\": mean_absolute_error(y_test_reg, y_test_predict)\n",
"    }\n",
"\n",
"metric_names_reg = ['MSE_train', 'MSE_test', 'R2_train', 'R2_test', 'MAE_train', 'MAE_test']\n",
"for model_name, results in results_reg.items():\n",
"    print(f\"Results for {model_name}:\")\n",
"    for metric in metric_names_reg:\n",
"        print(f\"{metric}: {results[metric]}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters for Logistic Regression: {'model__C': 10, 'model__solver': 'liblinear'}\n",
"Best parameters for Random Forest Classifier: {'model__max_depth': None, 'model__n_estimators': 100}\n",
"Best parameters for Gradient Boosting Classifier: {'model__learning_rate': 0.1, 'model__max_depth': 5, 'model__n_estimators': 300}\n",
"Model: Logistic Regression\n",
"Model: Random Forest Classifier\n",
"Model: Gradient Boosting Classifier\n",
"Results for Logistic Regression:\n",
"Precision_train: 0.8744992796427931\n",
"Precision_test: 0.8758164684756503\n",
"Recall_train: 0.841789667896679\n",
"Recall_test: 0.8431734317343174\n",
"Accuracy_train: 0.841789667896679\n",
"Accuracy_test: 0.8431734317343174\n",
"ROC_AUC_test: 0.9355010585912291\n",
"F1_train: 0.834330465466937\n",
"F1_test: 0.8350142662043457\n",
"MCC_test: 0.7851901595251827\n",
"Cohen_kappa_test: 0.7657628342341151\n",
"Confusion_matrix:\n",
"[[173 3 0]\n",
" [ 5 172 0]\n",
" [ 17 60 112]]\n",
"\n",
"Results for Random Forest Classifier:\n",
"Precision_train: 0.999080053300533\n",
"Precision_test: 0.9746115352901856\n",
"Recall_train: 0.9990774907749077\n",
"Recall_test: 0.974169741697417\n",
"Accuracy_train: 0.9990774907749077\n",
"Accuracy_test: 0.974169741697417\n",
"ROC_AUC_test: 0.9937592800800455\n",
"F1_train: 0.9990775021680652\n",
"F1_test: 0.974013999727198\n",
"MCC_test: 0.9616316062344435\n",
"Cohen_kappa_test: 0.96125843706283\n",
"Confusion_matrix:\n",
"[[176 0 0]\n",
" [ 0 175 2]\n",
" [ 4 8 177]]\n",
"\n",
"Results for Gradient Boosting Classifier:\n",
"Precision_train: 0.999080053300533\n",
"Precision_test: 0.9706437064370643\n",
"Recall_train: 0.9990774907749077\n",
"Recall_test: 0.9704797047970479\n",
"Accuracy_train: 0.9990774907749077\n",
"Accuracy_test: 0.9704797047970479\n",
"ROC_AUC_test: 0.9927092854328317\n",
"F1_train: 0.9990775021680652\n",
"F1_test: 0.9703257159470196\n",
"MCC_test: 0.9559528402941749\n",
"Cohen_kappa_test: 0.9557185020271858\n",
"Confusion_matrix:\n",
"[[176 0 0]\n",
" [ 0 173 4]\n",
" [ 4 8 177]]\n",
"\n"
]
}
],
"source": [
"from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import accuracy_score, cohen_kappa_score, confusion_matrix, f1_score, matthews_corrcoef, precision_score, recall_score, roc_auc_score\n",
"from sklearn.model_selection import GridSearchCV, train_test_split\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"import pandas as pd\n",
"\n",
"# Class boundaries: the 33rd and 66th percentiles of 'charges' (three roughly equal groups).\n",
"bins = [float('-inf'), df['charges'].quantile(0.33), df['charges'].quantile(0.66), float('inf')]\n",
"labels = ['Low', 'Medium', 'High']\n",
"\n",
"# Target as integer codes 0 = Low, 1 = Medium, 2 = High. It is kept as a separate\n",
"# Series rather than a new column of `df`, so re-running the regression cell\n",
"# cannot pick up this charges-derived label as a feature.\n",
"y_class = pd.cut(df['charges'], bins=bins, labels=labels).cat.codes\n",
"\n",
"# Features: everything except the target source; 'ChangesGroup' is dropped too in\n",
"# case `df` still carries that label column from an earlier kernel state.\n",
"X_class = df.drop(columns=['charges']).drop(columns=['ChangesGroup'], errors='ignore')\n",
"\n",
"X_train_class, X_test_class, y_train_class, y_test_class = train_test_split(X_class, y_class, test_size=0.2, random_state=13)\n",
"\n",
"models_class = {\n",
"    'Logistic Regression': LogisticRegression(random_state=13, max_iter=5000, solver='liblinear'),\n",
"    'Random Forest Classifier': RandomForestClassifier(random_state=13),\n",
"    'Gradient Boosting Classifier': GradientBoostingClassifier(random_state=13)\n",
"}\n",
"\n",
"# Every classifier is preceded by feature standardization.\n",
"pipelines_class = {\n",
"    name: Pipeline([('scaler', StandardScaler()), ('model', model)])\n",
"    for name, model in models_class.items()\n",
"}\n",
"\n",
"param_grids_class = {\n",
"    'Logistic Regression': {\n",
"        'model__C': [0.1, 1, 10],\n",
"        'model__solver': ['lbfgs', 'liblinear']\n",
"    },\n",
"    'Random Forest Classifier': {\n",
"        'model__n_estimators': [100, 200, 300],\n",
"        'model__max_depth': [None, 10, 20, 30]\n",
"    },\n",
"    'Gradient Boosting Classifier': {\n",
"        'model__n_estimators': [100, 200, 300],\n",
"        'model__learning_rate': [0.01, 0.1, 0.2],\n",
"        'model__max_depth': [3, 5, 7]\n",
"    }\n",
"}\n",
"\n",
"# 5-fold CV on the training split. best_estimator_ is already a scaler + model\n",
"# Pipeline refit on the whole training split, so it is evaluated as-is below.\n",
"best_models_class = {}\n",
"for name, pipeline in pipelines_class.items():\n",
"    grid_search = GridSearchCV(pipeline, param_grids_class[name], cv=5, scoring='accuracy', n_jobs=-1)\n",
"    grid_search.fit(X_train_class, y_train_class)\n",
"    best_models_class[name] = {\"model\": grid_search.best_estimator_}\n",
"    print(f'Best parameters for {name}: {grid_search.best_params_}')\n",
"\n",
"# Evaluate the tuned pipelines directly: wrapping them in another StandardScaler\n",
"# and fitting again would only scale twice and repeat the training.\n",
"for model_name, entry in best_models_class.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    model_pipeline = entry[\"model\"]\n",
"\n",
"    y_train_predict = model_pipeline.predict(X_train_class)\n",
"    y_test_probs = model_pipeline.predict_proba(X_test_class)\n",
"    y_test_predict = model_pipeline.predict(X_test_class)\n",
"\n",
"    entry.update({\n",
"        \"pipeline\": model_pipeline,\n",
"        \"probs\": y_test_probs,\n",
"        \"preds\": y_test_predict,\n",
"        \"Precision_train\": precision_score(y_train_class, y_train_predict, average='weighted'),\n",
"        \"Precision_test\": precision_score(y_test_class, y_test_predict, average='weighted'),\n",
"        \"Recall_train\": recall_score(y_train_class, y_train_predict, average='weighted'),\n",
"        \"Recall_test\": recall_score(y_test_class, y_test_predict, average='weighted'),\n",
"        \"Accuracy_train\": accuracy_score(y_train_class, y_train_predict),\n",
"        \"Accuracy_test\": accuracy_score(y_test_class, y_test_predict),\n",
"        \"ROC_AUC_test\": roc_auc_score(y_test_class, y_test_probs, multi_class='ovr'),\n",
"        \"F1_train\": f1_score(y_train_class, y_train_predict, average='weighted'),\n",
"        \"F1_test\": f1_score(y_test_class, y_test_predict, average='weighted'),\n",
"        \"MCC_test\": matthews_corrcoef(y_test_class, y_test_predict),\n",
"        \"Cohen_kappa_test\": cohen_kappa_score(y_test_class, y_test_predict),\n",
"        \"Confusion_matrix\": confusion_matrix(y_test_class, y_test_predict)\n",
"    })\n",
"\n",
"metric_names_class = ['Precision_train', 'Precision_test', 'Recall_train', 'Recall_test', 'Accuracy_train', 'Accuracy_test', 'ROC_AUC_test', 'F1_train', 'F1_test', 'MCC_test', 'Cohen_kappa_test']\n",
"for model_name, results in best_models_class.items():\n",
"    print(f\"Results for {model_name}:\")\n",
"    for metric in metric_names_class:\n",
"        print(f\"{metric}: {results[metric]}\")\n",
"    print(f\"Confusion_matrix:\\n{results['Confusion_matrix']}\")\n",
"    print()\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAc8AAAQ9CAYAAAA/GsaeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAD9Y0lEQVR4nOzdd1gU1xoG8HdpC9J7UVDsYC/RoCigRCyxBHvUgDUmYu83YotKNBqNlcQY0URiEluUJMaKXWPDFoKiqFgADdJD2537B2F0BXQHF3Hh/T3PPDd75szMNyuXj3PmzDkyQRAEEBERkdp0yjsAIiIibcPkSUREJBGTJxERkURMnkRERBIxeRIREUnE5ElERCQRkycREZFETJ5EREQSMXkSERFJxORJFYK3tze8vb01dr4aNWogMDBQY+cjQCaTYe7cueUdBpFGMHmSRoWFhUEmk+HcuXPlHcpLnTx5EnPnzkVKSkqZXqdGjRqQyWTiZmxsjFatWmHz5s1lel0iKjt65R0AkSbs27dP8jEnT57EvHnzEBgYCAsLC5V9MTEx0NHR3N+WTZs2xeTJkwEADx8+xDfffIOAgADk5ORg5MiRGrvOm+zff/+Fnh5/5VDFwJ9kqhAMDAw0ej65XK7R81WtWhWDBw8WPwcGBqJmzZpYvnz5a0+emZmZMDY2fq3XBABDQ8PXfk2issJuWyoXFy9eRJcuXWBmZgYTExN07NgRp0+fLlLv8uXL8PLygpGREapVq4YFCxZg48aNkMlkuH37tlivuGeeq1atQoMGDVClShVYWlqiZcuWCA8PBwDMnTsXU6dOBQC4urqKXaqF5yzumWdKSgomTpyIGjVqQC6Xo1q1avjggw/w+PFjyfdva2uL+vXr4+bNmyrlSqUSK1asQIMGDWBoaAh7e3t8+OGHePLkSZF6c+fOhZOTE6pUqQIfHx/89ddfReIu7EY/cuQIPv74Y9jZ2aFatWri/t9//x3t2rWDsbExTE1N0a1bN1y7dk3lWgkJCRg6dCiqVasGuVwOR0dH9OzZU+X7P3fuHPz8/GBjYwMjIyO4urpi2LBhKucp7pmnOj8Hhfdw4sQJTJo0Cba2tjA2NsZ7772HR48eqfuVE2kUW5702l27dg3t2rWDmZkZpk2bBn19fXz11Vfw9vbGkSNH0Lp1awDA/fv34ePjA5lMhpkzZ8LY2BjffPONWq3C9evXY9y4cejTpw/Gjx+P7OxsXL58GWfOnMH7778Pf39/XL9+HT/88AOWL18OGxsbAAVJrTgZGRlo164doqOjMWzYMDRv3hyPHz/G7t27ce/ePfF4deXn5+PevXuwtLRUKf/www8RFhaGoUOHYty4cYiLi8Pq1atx8eJFnDhxAvr6+gCAmTNnYsmSJejevTv8/Pxw6dIl+Pn5ITs7u9jrffzxx7C1tcXs2bORmZkJAPjuu+8QEBAAPz8/LF68GFlZWVi3bh08PT1x8eJF1KhRAwDQu3dvXLt2DWPHjkWNGjWQlJSE/fv34+7du+LnTp06wdbWFjNmzICFhQVu376NHTt2vPA7UPfnoNDYsWNhaWmJOXPm4Pbt21ixYgWCgoLw448/SvruiTRCINKgjRs3CgCEs2fPllinV69egoGBgXDz5k2x7MGDB4KpqanQvn17sWzs2LGCTCYTLl68KJb9888/gpWVlQBAiIuLE8u9vLwELy8v8XPPnj2FBg0avDDWzz//vMh5ClWvXl0ICAgQP8+ePVsAIOzYsaNIXaVS+cLrVK9eXejUqZPw6NEj4dGjR8KVK1eEIUOGCACEMWPGiPWOHTsmABC2bNmicvzevXtVyhMSEgQ9PT2hV69eKvXmzp0rAFCJu/Dfw9PTU8jPzxfL09PTBQsLC2HkyJEq50hISBDMzc3F8idPnggAhM8//7zE+9u5c+dL/80FQRAACHPmzBE/q/tzUHgPvr6+Kt/1xIkTBV1dXSElJeWF1yUqC+y2pddKoV
Bg37596NWrF2rWrCmWOzo64v3338fx48eRlpYGANi7dy88PDzQtGlTsZ6VlRUGDRr00utYWFjg3r17OHv2rEbi3r59O5o0aYL33nuvyD6ZTPbS4/ft2wdbW1vY2tqiUaNG+O677zB06FB8/vnnYp2ff/4Z5ubmeOedd/D48WNxa9GiBUxMTHD48GEAwMGDB5Gfn4+PP/5Y5Rpjx44t8fojR46Erq6u+Hn//v1ISUnBwIEDVa6lq6uL1q1bi9cyMjKCgYEBIiMji3QdFyocbBUREYG8vLyXfheAtJ+DQqNGjVL5rtu1aweFQoE7d+6odU0iTWLypNfq0aNHyMrKQr169Yrsc3Nzg1KpRHx8PADgzp07qF27dpF6xZU9b/r06TAxMUGrVq1Qp04djBkzBidOnCh13Ddv3kTDhg1LfXzr1q2xf/9+7N27F0uXLoWFhQWePHmiMtDpxo0bSE1NhZ2dnZhoC7eMjAwkJSUBgJgsnv8erKysinQDF3J1dVX5fOPGDQBAhw4dilxr37594rXkcjkWL16M33//Hfb29mjfvj2WLFmChIQE8VxeXl7o3bs35s2bBxsbG/Ts2RMbN25ETk5Oid+HlJ+DQi4uLiqfC++1pKROVJb4zJMqJDc3N8TExCAiIgJ79+7F9u3bsXbtWsyePRvz5s177fHY2NjA19cXAODn54f69evj3XffxZdffolJkyYBKBgEZGdnhy1bthR7jpKex6rDyMhI5bNSqQRQ8NzTwcGhSP1nXymZMGECunfvjl27duGPP/5AcHAwQkJCcOjQITRr1gwymQzbtm3D6dOnsWfPHvzxxx8YNmwYli1bhtOnT8PExKTUcT/r2ZbzswRB0Mj5iaRg8qTXytbWFlWqVEFMTEyRfX///Td0dHTg7OwMAKhevTpiY2OL1CuurDjGxsbo378/+vfvj9zcXPj7+2PhwoWYOXMmDA0N1epuLVSrVi1cvXpV7fov061bN3h5eWHRokX48MMPYWxsjFq1auHAgQNo27ZtkWT3rOrVqwMo+B6ebVH+888/arfCatWqBQCws7MTk/rL6k+ePBmTJ0/GjRs30LRpUyxbtgzff/+9WOftt9/G22+/jYULFyI8PByDBg3C1q1bMWLEiCLnk/JzQPQmYrctvVa6urro1KkTfvnlF5VXHRITExEeHg5PT0+YmZkBKGihnTp1ClFRUWK95OTkEltmz/rnn39UPhsYGMDd3R2CIIjP5QrfdVRnhqHevXvj0qVL2LlzZ5F9pW35TJ8+Hf/88w/Wr18PAOjXrx8UCgU+/fTTInXz8/PFODt27Ag9PT2sW7dOpc7q1avVvrafnx/MzMywaNGiYp9TFr4CkpWVVWQEb61atWBqaip2yz558qTId1D4nLqkrlspPwdEbyK2PKlMfPvtt9i7d2+R8vHjx2PBggXYv38/PD098fHHH0NPTw9fffUVcnJysGTJErHutGnT8P333+Odd97B2LFjxVdVXFxckJyc/MKWY6dOneDg4IC2bdvC3t4e0dHRWL16Nbp16wZTU1MAQIsWLQAAn3zyCQYMGAB9fX1079692AkEpk6dim3btqFv374YNmwYWrRogeTkZOzevRuhoaFo0qSJ5O+oS5cuaNiwIb744guMGTMGXl5e+PDDDxESEoKoqCh06tQJ+vr6uHHjBn7++Wd8+eWX6NOnD+zt7TF+/HgsW7YMPXr0QOfOnXHp0iX8/vvvsLGxUatFbWZmhnXr1mHIkCFo3rw5BgwYAFtbW9y9exe//vor2rZti9WrV+P69evo2LEj+vXrB3d3d+jp6WHnzp1ITEzEgAEDAACbNm3C2rVr8d5776FWrVpIT0/H+vXrYWZmhq5du5YYg7o/B0RvpPId7EsVTeFrBSVt8fHxgiAIwoULFwQ/Pz/BxMREqFKliuDj4yOcPHmyyPkuXrwotGvXTpDL5UK1atWEkJAQYeXKlQIAISEhQaz3/KsqX331ldC+fXvB2tpakM
vlQq1atYSpU6cKqampKuf/9NNPhapVqwo6Ojoqr608/6qKIBS8JhMUFCRUrVpVMDAwEKpVqyYEBAQIjx8/fuF3Ur1
"text/plain": [
"<Figure size 1200x1000 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"\n",
"# One row of axes per classifier; squeeze=False keeps `axes` a 2-D array even\n",
"# when there is only one model, so `axes.flat` always works.\n",
"num_models = len(best_models_class)\n",
"fig, axes = plt.subplots(num_models, 1, figsize=(12, 10), squeeze=False)\n",
"\n",
"# Label order matches the integer class codes 0 = Low, 1 = Medium, 2 = High.\n",
"for axis, (key, results) in zip(axes.flat, best_models_class.items()):\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=results[\"Confusion_matrix\"], display_labels=[\"Low\", \"Medium\", \"High\"]\n",
"    ).plot(ax=axis)\n",
"    disp.ax_.set_title(key)\n",
"\n",
"# tight_layout keeps titles and tick labels inside the figure; subplots_adjust(top=1, bottom=0)\n",
"# placed the top title and the bottom x-axis label outside the figure area.\n",
"fig.tight_layout()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
2024-12-06 22:07:54 +04:00
"execution_count": 18,
2024-12-06 21:40:44 +04:00
"metadata": {},
"outputs": [
{
"data": {
2024-12-06 22:07:54 +04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABOcAAAQ9CAYAAAASxyn4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3xV9f348dfdM3syAoGwt6yICFhRo0VbK9bRqojaqkUctFTtcK+vtFVbVLRW8Ndq3XaI4kCxKsgGBWUKBAnZ497cfe85vz8O995cMkggZMD7+Xjch8k5n5z7uSHe5L7ve+hUVVURQgghhBBCCCGEEEJ0OH1nb0AIIYQQQgghhBBCiJOVBOeEEEIIIYQQQgghhOgkEpwTQgghhBBCCCGEEKKTSHBOCCGEEEIIIYQQQohOIsE5IYQQQgghhBBCCCE6iQTnhBBCCCGEEEIIIYToJBKcE0IIIYQQQgghhBCik0hwTgghhBBCCCGEEEKITiLBOSGEEEIIIYQQQgghOokE54QQQoiT3N69e9HpdCxZsqSztyKOwdVXX01+fn5nb+OEodPpuOeeezp7G53ujDPO4Iwzzoh93hWfLw7foxBCCNHdSHBOCCGEOIEtWbIEnU7HunXrOnsrx80999yDTqeL3UwmE/n5+dx8883U1tZ29vYE8NRTT6HT6SgsLDzqa5SUlHDPPfewadOm9ttYF7dixYpGP9v9+/fnqquu4ttvv+3s7bXJypUrueeee+T/SSGEEKIJxs7egBBCCCE6V9++ffH5fJhMps7eyjF5+umncTqdeDweli9fzl/+8hc2bNjAZ5991tlb6xB//etfURSls7fRpBdffJH8/HzWrFnDrl27GDBgQJuvUVJSwr333kt+fj5jxoxp/012YTfffDMTJkwgFAqxYcMGnn32WZYuXcpXX31Fz549O3QvR/t8sXLlSu69916uvvpqUlNTj8/mhBBCiG5KMueEEEKIk5xOp8NqtWIwGDp7K83yer1HXHPxxRdzxRVXcP311/Pqq69y6aWX8vnnn7NmzZoO2GGcoij4/f4OvU8Ak8mExWLp8Ps9kj179rBy5Ur+9Kc/kZWVxYsvvtjZW+p2pkyZwhVXXMHs2bP5y1/+wh/+8Aeqq6t54YUXmv0aj8dzXPbSHZ4vhBBCiO5GgnNCCCHESa6pHlJXX301TqeTAwcOcOGFF+J0OsnKyuJXv/oVkUgk4esVReHxxx9n+PDhWK1WcnJyuP7666mpqUlY9+9//5sZM2bQs2dPLBYLBQUF3H///Y2ud8YZZzBixAjWr1/P1KlTsdvt/OY3v2nz45oyZQoAu3fvTji+evVqzj33XFJSUrDb7UybNo3PP/+80devWLGC8ePHY7VaKSgo4JlnnomV0Dak0+m46aabePHFFxk+fDgWi4Vly5YBcODAAa655hpycnKwWCwMHz6c559/vtF9/eUvf2H48OHY7XbS0tIYP348L730Uuy82+3m1ltvJT8/H4vFQnZ2NmeffTYbNmyIrWmq55zH4+GXv/wleXl5WCwWBg8ezB/+8AdUVW3yMfzrX/9ixIgRsb1GH0dD27Zto7i4uKlveZNefPFF0tLSmDFjBhdffHGzwbna2lpuu+222GPs3bs3V111FZWVlaxYsYIJEyYAMHv27FiZZ/RnNj8/n6uvvrrRNQ/vRRYMBrnrrrsYN24cKSkpOBwOpkyZwscff9zqxxNVVlaG0Wjk3nvvbXRu+/bt6HQ6Fi5cCEAoFOLee+9l4MCBWK1WMjIyOP300/nggw/afL8AZ555JqAFPiFe2v3111/zk5/8hLS0NE4//fTY+n/84x+MGzcOm81Geno6l112Gfv372903WeffZaCggJsNhsTJ07k008/bbSmuZ5z27Zt45JLLiErKwubzcbgwYP57W9/G9vf/PnzAejXr1/s32/v3r3HZY9CCCFEdyNlrUIIIYRoUiQSoaioiMLCQv7whz/w4Y
cf8sc//pGCggJuvPHG2Lrrr7+eJUuWMHv2bG6++Wb27NnDwoUL2bhxI59//nms/G3JkiU4nU7mzZuH0+nko48+4q677sLlcrFgwYKE+66qquK8887jsssu44orriAnJ6fN+4++8E9LS4sd++ijjzjvvPMYN24cd999N3q9nsWLF3PmmWfy6aefMnHiRAA2btzIueeeS48ePbj33nuJRCLcd999ZGVlNXlfH330Ea+++io33XQTmZmZ5OfnU1ZWxqmnnhoLfGVlZfHuu+9y7bXX4nK5uPXWWwGtHPXmm2/m4osv5pZbbsHv9/Pll1+yevVqfvKTnwBwww038Prrr3PTTTcxbNgwqqqq+Oyzz/jmm28YO3Zsk3tSVZUf/OAHfPzxx1x77bWMGTOG9957j/nz53PgwAEee+yxhPWfffYZb775Jr/4xS9ISkriz3/+MzNnzqS4uJiMjIzYuqFDhzJt2jRWrFjRqn+HF198kYsuugiz2czll1/O008/zdq1a2PBNoD6+nqmTJnCN998wzXXXMPYsWOprKzkP//5D9999x1Dhw7lvvvu46677uLnP/95LPB62mmntWoPUS6Xi+eee47LL7+cn/3sZ7jdbv72t79RVFTEmjVr2lQum5OTw7Rp03j11Ve5++67E8698sorGAwGfvzjHwNacOrhhx/muuuuY+LEibhcLtatW8eGDRs4++yz2/QYIB5wbvjvAvDjH/+YgQMH8tBDD8UCsA8++CC///3vueSSS7juuuuoqKjgL3/5C1OnTmXjxo2xEtO//e1vXH/99Zx22mnceuutfPvtt/zgBz8gPT2dvLy8Fvfz5ZdfMmXKFEwmEz//+c/Jz89n9+7d/Pe//+XBBx/koosuYseOHfzzn//kscceIzMzEyD2/1NH7FEIIYTo0lQhhBBCnLAWL16sAuratWubXbNnzx4VUBcvXhw7NmvWLBVQ77vvvoS1p5xyijpu3LjY559++qkKqC+++GLCumXLljU67vV6G9339ddfr9rtdtXv98eOTZs2TQXURYsWteox3n333Sqgbt++Xa2oqFD37t2rPv/886rNZlOzsrJUj8ejqqqqKoqiDhw4UC0qKlIVRUnYV79+/dSzzz47duyCCy5Q7Xa7euDAgdixnTt3qkajUT38zydA1ev16tatWxOOX3vttWqPHj3UysrKhOOXXXaZmpKSEvt+/PCHP1SHDx/e4mNMSUlR58yZ0+KaWbNmqX379o19/q9//UsF1AceeCBh3cUXX6zqdDp1165dCY/BbDYnHNu8ebMKqH/5y18aPd5p06a1uJeodevWqYD6wQcfqKqq/Rv07t1bveWWWxLW3XXXXSqgvvnmm42uEf23Wrt2baOf06i+ffuqs2bNanR82rRpCXsNh8NqIBBIWFNTU6Pm5OSo11xzTcJxQL377rtbfHzPPPOMCqhfffVVwvFhw4apZ555Zuzz0aNHqzNmzGjxWk35+OOPVUB9/vnn1YqKCrWkpERdunSpmp+fr+p0utj/19H/By6//PKEr9+7d69qMBjUBx98MOH4V199pRqNxtjxYDCoZmdnq2PGjEn4/jz77LON/r2ber6YOnWqmpSUpO7bty/hfhr+f7ZgwQIVUPfs2XPc9yiEEEJ0N1LWKoQQQohm3XDDDQmfT5kyJWFK5GuvvUZKSgpnn302lZWVsdu4ceNwOp0J5YI2my32sdvtprKykilTpuD1etm2bVvC/VgsFmbPnt2mvQ4ePJisrCzy8/O55pprGDBgAO+++y52ux2ATZs2sXPnTn7yk59QVVUV26vH42H69On873//Q1EUIpEIH374IRdeeGFCs/0BAwZw3nnnNXnf06ZNY9iwYbHPVVXljTfe4IILLkBV1YTvTVFREXV1dbGS1NTUVL777jvWrl3b7GNLTU1l9erVlJSUtPr78c4772AwGLj55psTjv/yl79EVVXefffdhONnnXUWBQUFsc9HjRpFcnJyo6mgqqq2KW
suJyeH733ve4BWPnvppZfy8ssvJ5Qzv/HGG4wePZof/ehHja5xeBnxsTAYDJjNZkArx66uriYcDjN+/PiEEuHWuui
2024-12-06 21:40:44 +04:00
"text/plain": [
2024-12-06 22:07:54 +04:00
"<Figure size 1500x1000 with 3 Axes>"
2024-12-06 21:40:44 +04:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"\n",
"# One row of axes per regression model: predicted vs. actual charges on the test split.\n",
"# squeeze=False keeps `axes` a 2-D array even when there is only one model.\n",
"fig, axes = plt.subplots(len(results_reg), 1, figsize=(15, 10), squeeze=False)\n",
"\n",
"for axis, (name, model_results) in zip(axes.flat, results_reg.items()):\n",
"    y_pred_reg = model_results[\"preds_test\"]\n",
"\n",
"    axis.scatter(y_test_reg, y_pred_reg, alpha=0.5)\n",
"    # Dashed red diagonal y = x: a perfect model would put every point on it.\n",
"    limits = [y_test_reg.min(), y_test_reg.max()]\n",
"    axis.plot(limits, limits, color='red', linestyle='--')\n",
"    axis.set_xlabel('Actual Values')\n",
"    axis.set_ylabel('Predicted Values')\n",
"    axis.set_title(f'{name}: Actual vs Predicted')\n",
"\n",
"# tight_layout keeps titles and axis labels inside the figure.\n",
"fig.tight_layout()\n",
"plt.show()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}