AIM-PIbd-31-Rodionov-I-A/lab_3/lab3.ipynb

1031 lines
115 KiB
Plaintext
Raw Normal View History

2024-11-01 23:33:34 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Данные по инсультам\n",
"\n",
"Выведем информацию о столбцах датасета:"
]
},
{
"cell_type": "code",
"execution_count": 441,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['id', 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married',\n",
" 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi',\n",
" 'smoking_status', 'stroke'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9046</td>\n",
" <td>Male</td>\n",
" <td>67.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>228.69</td>\n",
" <td>36.6</td>\n",
" <td>formerly smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>51676</td>\n",
" <td>Female</td>\n",
" <td>61.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>202.21</td>\n",
" <td>NaN</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>31112</td>\n",
" <td>Male</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>105.92</td>\n",
" <td>32.5</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>60182</td>\n",
" <td>Female</td>\n",
" <td>49.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>171.23</td>\n",
" <td>34.4</td>\n",
" <td>smokes</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1665</td>\n",
" <td>Female</td>\n",
" <td>79.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>174.12</td>\n",
" <td>24.0</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id gender age hypertension heart_disease ever_married \\\n",
"0 9046 Male 67.0 0 1 Yes \n",
"1 51676 Female 61.0 0 0 Yes \n",
"2 31112 Male 80.0 0 1 Yes \n",
"3 60182 Female 49.0 0 0 Yes \n",
"4 1665 Female 79.0 1 0 Yes \n",
"\n",
" work_type Residence_type avg_glucose_level bmi smoking_status \\\n",
"0 Private Urban 228.69 36.6 formerly smoked \n",
"1 Self-employed Rural 202.21 NaN never smoked \n",
"2 Private Rural 105.92 32.5 never smoked \n",
"3 Private Urban 171.23 34.4 smokes \n",
"4 Self-employed Rural 174.12 24.0 never smoked \n",
"\n",
" stroke \n",
"0 1 \n",
"1 1 \n",
"2 1 \n",
"3 1 \n",
"4 1 "
]
},
"execution_count": 441,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from sklearn.preprocessing import StandardScaler\n",
"import featuretools as ft\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"import time\n",
"from sklearn.metrics import root_mean_squared_error, r2_score, mean_absolute_error\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"\n",
"df = pd.read_csv(\"..//..//static//csv//healthcare-dataset-stroke-data.csv\")\n",
"\n",
"print(df.columns)\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Определим бизнес цели и цели технического проекта.\n",
"\n",
"1. Улучшение диагностики и профилактики инсульта.\n",
" * Бизнес-цель: повышение точности прогнозирования риска инсульта среди пациентов для более раннего лечебного вмешательства. Определение основных факторов риска для более целенаправленного подхода в медицинском обслуживании.\n",
" * Цель технического проекта: разработка статистической модели, которая решает задачу классификации и предсказывает возможность возникновения инсульта у пациентов на основе имеющихся данных (возраст, гипертония, заболевания сердца и пр.), с целью выявления групп риска. Внедрение этой модели в систему поддержки принятия медицинских решений для врачей.\n",
"2. Снижение расходов на лечение инсультов.\n",
" * Бизнес-цель: снижение затрат на лечение инсульта путем более эффективного распределения медицинских ресурсов и направленных профилактических мер.\n",
" * Цель технического проекта: создание системы оценки индивидуального риска инсульта для пациентов, что позволит медучреждениям проводить профилактические меры среди целевых групп, сокращая расходы на лечение.\n",
"\n",
"### И теперь проверим датасет на пустые значения:"
]
},
{
"cell_type": "code",
"execution_count": 442,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"id 0\n",
"gender 0\n",
"age 0\n",
"hypertension 0\n",
"heart_disease 0\n",
"ever_married 0\n",
"work_type 0\n",
"Residence_type 0\n",
"avg_glucose_level 0\n",
"bmi 201\n",
"smoking_status 0\n",
"stroke 0\n",
"dtype: int64\n",
"\n",
"id False\n",
"gender False\n",
"age False\n",
"hypertension False\n",
"heart_disease False\n",
"ever_married False\n",
"work_type False\n",
"Residence_type False\n",
"avg_glucose_level False\n",
"bmi True\n",
"smoking_status False\n",
"stroke False\n",
"dtype: bool\n",
"\n",
"bmi процент пустых значений: %3.93\n"
]
}
],
"source": [
"# Количество пустых значений признаков\n",
"print(df.isnull().sum())\n",
"\n",
"print()\n",
"\n",
"# Есть ли пустые значения признаков\n",
"print(df.isnull().any())\n",
"\n",
"print()\n",
"\n",
"# Процент пустых значений признаков\n",
"for i in df.columns:\n",
" null_rate = df[i].isnull().sum() / len(df) * 100\n",
" if null_rate > 0:\n",
" print(f\"{i} процент пустых значений: %{null_rate:.2f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В столбце bmi можно заметить пустые значение. Заменим их на медиану:"
]
},
{
"cell_type": "code",
"execution_count": 443,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Количество пустых значений в каждом столбце после замены:\n",
"id 0\n",
"gender 0\n",
"age 0\n",
"hypertension 0\n",
"heart_disease 0\n",
"ever_married 0\n",
"work_type 0\n",
"Residence_type 0\n",
"avg_glucose_level 0\n",
"bmi 0\n",
"smoking_status 0\n",
"stroke 0\n",
"dtype: int64\n"
]
}
],
"source": [
"# Замена значений\n",
"df[\"bmi\"] = df[\"bmi\"].fillna(df[\"bmi\"].median())\n",
"\n",
"# Проверка на пропущенные значения после замены\n",
"missing_values_after_drop = df.isnull().sum()\n",
"\n",
"# Вывод результатов после замены\n",
"print(\"\\nКоличество пустых значений в каждом столбце после замены:\")\n",
"print(missing_values_after_drop)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Можно перейти к созданию выборок"
]
},
{
"cell_type": "code",
"execution_count": 444,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Размер обучающей выборки: (2503, 11)\n",
"Размер контрольной выборки: (1074, 11)\n",
"Размер тестовой выборки: (1533, 11)\n"
]
}
],
"source": [
"# Разделение данных на признаки (X) и целевую переменную (y)\n",
"# В данном случае мы хотим предсказать 'stroke'\n",
"X = df.drop(columns=['stroke'])\n",
"y = df['stroke']\n",
"\n",
"# Разбиение данных на обучающую и тестовую выборки\n",
"# Сначала разделим на обучающую и тестовую\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n",
"\n",
"# Затем разделим обучающую выборку на обучающую и контрольную\n",
"X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.3)\n",
"\n",
"# Проверка размеров выборок\n",
"print(\"Размер обучающей выборки:\", X_train.shape)\n",
"print(\"Размер контрольной выборки:\", X_val.shape)\n",
"print(\"Размер тестовой выборки:\", X_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Оценим сбалансированность выборок:"
]
},
{
"cell_type": "code",
"execution_count": 445,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Распределение классов в обучающей выборке:\n",
"stroke\n",
"0 0.955653\n",
"1 0.044347\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в контрольной выборке:\n",
"stroke\n",
"0 0.954376\n",
"1 0.045624\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в тестовой выборке:\n",
"stroke\n",
"0 0.941944\n",
"1 0.058056\n",
"Name: proportion, dtype: float64\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABboAAAHyCAYAAAAtJXgGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABpfElEQVR4nO3dd3RU1d7G8ScJ6SEgpFJM6KGDoQiItECA0PQKUq4EVEABFfCigkoAS0QUQUApChbwiqDgVZQqKGIEAVGQIiVID4QWagKZ/f7ByrwMM4FQZHLw+1kra2X27HPO78xkZp95cmYfD2OMEQAAAAAAAAAAFuXp7gIAAAAAAAAAALgRBN0AAAAAAAAAAEsj6AYAAAAAAAAAWBpBNwAAAAAAAADA0gi6AQAAAAAAAACWRtANAAAAAAAAALA0gm4AAAAAAAAAgKURdAMAAAAAAAAALI2gGwAAAACug81mU3p6unbu3OnuUgAAAP7xCLoBAAAAII8OHjyoAQMGKCoqSj4+PgoNDVWlSpWUkZHh7tIAAAD+0Qq4uwAAAICb7YMPPlDPnj3tt319fXXnnXeqRYsWevHFFxUeHu7G6gBY1fbt29WkSROdP39eTz75pO666y4VKFBA/v7+CgwMdHd5AAAA/2gE3QAA4LY1cuRIlSpVSufOndOPP/6od999V9988402btyogIAAd5cHwGL69OkjHx8f/fzzzypevLi7ywEAAMAlCLoBAMBtq1WrVqpVq5Yk6dFHH1XRokU1ZswYffnll+rSpYubqwNgJWvXrtV3332nRYsWEXIDAADkQ8zRDQAA/jGaNm0qSUpNTZUkHT16VP/5z39UtWpVBQUFKTg4WK1atdJvv/3mtOy5c+c0fPhwlS9fXn5+foqMjNT999+vHTt2SJJ27dolDw+PXH8aN25sX9fy5cvl4eGhWbNmaejQoYqIiFBgYKDatWunPXv2OG171apVatmypQoVKqSAgAA1atRIK1eudLmPjRs3drn94cOHO/WdMWOGYmNj5e/vryJFiqhz584ut3+lfbuUzWbT2LFjVblyZfn5+Sk8PFx9+vTRsWPHHPpFR0erTZs2Ttvp37+/0zpd1T569Ginx1SSMjMzlZSUpLJly8rX11clS5bUM888o8zMTJeP1aUuf9xCQkKUkJCgjRs35mnZKlWqaO3atapfv778/f1VqlQpTZo0yaFfVlaWhg0bptjYWBUqVEiBgYFq2LChli1b5tBv69atatq0qSIiIuz78dhjj+no0aNO2+7Ro8dVn+8ePXooOjraYbk9e/bI399fHh4e2rVrl6T/f54/+OADh77Dhw93+bz079/fqZ42bdo4bCtnnW+88UYuj57z+qdPny4PDw9NmzbNod+rr74qDw8PffPNN7muS7r495XzOHh6eioiIkIPPvigdu/efUN1/fzzz/Lz89OOHTtUuXJl+fr6KiIiQn369HH53MyePdv++goJCdG///1v7du3z6FPjx49FBQUpJ07dyo+Pl6BgYEqVqyYRo4cKWOMU72XPjcnT55UbGysSpUqpQMHDtjb33jjDdWvX19FixaVv7+/YmNjNWfOHIft3uhjDAAAkB9xRjcAAPjHyAmlixYtKknauXOn5s2bp44dO6pUqVJKS0vT5MmT1ahRI23atEnFihWTJGVnZ6tNmzZaunSpOnfurKeeekonT57U4sWLtXHjRpUpU8a+jS5duqh169YO2x0yZIjLel555RV5eHjo2Wef1aFDhzR27FjFxcVp/fr18vf3lyR99913atWqlWJjY5WUlCRPT09Nnz5dTZs21YoVK1SnTh2n9ZYoUULJycmSpFOnTunxxx93ue0XX3xRnTp10qOPPqrDhw9r/Pjxuvfee/Xrr7+qcOHCTsv07t1bDRs2lCR98cUXmjt3rsP9ffr0sc+P/uSTTyo1NVUTJkzQr7/+qpUrV8rb29vl43Atjh8/bt+3S9lsNrVr104//vijevfurYoVK2rDhg1666239Oeff2revHlXXXdMTIyef/55GWO0Y8cOjRkzRq1bt3YISHNz7NgxtW7dWp06dVKXLl302Wef6fHHH5ePj48efvhhSVJGRobee+89denSRb169dLJkyf1/vvvKz4+XqtXr1aNGjUkSadPn1aJEiXUtm1bBQcHa+PGjZo4caL27dunr776ymnbISEheuutt+y3H3rooavWO2zYMJ07d+6q/dyhZ8+e+uKLLzRo0CA1b95cJUuW1IYNGzRixAg98sgjTq8vVxo2bKjevXvLZrNp48aNGjt2rPbv368VK1Zcd11HjhzRuXPn9Pjjj6tp06Z67LHHtGPHDk2cOFGrVq3SqlWr5OvrK+n/rxNQu3ZtJScnKy0tTePGjdPKlSudXl/Z2dlq2bKl7r77br3++utasGCBkpKSdOHCBY0cOdJlLefPn9e//vUv7d69WytXrlRkZKT9vnHjxqldu3bq1q2bsrKy9Omnn6pjx476+uuvlZCQcNMeYwAAgHzHAAAA3GamT59uJJklS5aYw4cPmz179phPP/3UFC1a1Pj7+5u9e/caY4w5d+6cyc7Odlg2NTXV+Pr6mpEjR9rbpk2bZiSZMWPGOG3LZrPZl5NkRo8e7dSncuXKplGjRvbby5YtM5JM8eLFTUZGhr39s88+M5LMuHHj7OsuV66ciY+Pt2/HGGPOnDljSpUqZZo3b+60rfr165sqVarYbx8+fNhIMklJSfa2Xbt2GS8vL/PKK684LLthwwZToEABp/Zt27YZSebDDz+0tyUlJZlLDyVXrFhhJJmZM2c6LLtgwQKn9qioKJOQkOBUe79+/czlh6eX1/7MM8+YsLAwExsb6/CYfvzxx8bT09OsWLHCYflJkyYZSWblypVO27tUo0aNHNZnjDFDhw41ksyhQ4euuqwk8+abb9rbMjMzTY0aNUxYWJjJysoyxhhz4cIFk5mZ6bDssWPHTHh4uHn44YevuI2+ffuaoKAgp/Zu3bqZUqVKObRd/pglJiaaqKgo++2NGzcaT09P06pVKyPJpKamGmOM+euvv4wkM23aNIf1Xf5c52yjX79+TvUkJCQ4bOtKr4srrf/AgQOmSJEipnnz5iYzM9PUrFnT3HnnnebEiRO5ridHVFSUSUxMdGjr2rWrCQgIuKG6cm43a9bMXLhwwd6e834zfvx4Y4wxWVlZJiwszFSpUsWcPXvW3u/rr782ksywYcPsbYmJiUaSeeKJJ+xtNpvNJCQkGB8fH3P48GGHeqdPn25sNpvp1q2bCQgIMKtWrXKq+8yZMw63s7KyTJUqVUzTpk0d2m/kMQYAAMiPmLoEAADctuLi4hQaGqqSJUuqc+fOCgoK0ty5c+3z6/r6+srT8+LhUHZ2to4cOaKgoCBVqFBB69ats6/n888/V0hIiJ544gmnbVw+pcO16N69uwoWLGi//cADDygyMtI+bcD69eu1bds2de3aVUeOHFF6errS09N1+vRpNWvWTD/88INsNpvDOs+dOyc/P78rbveLL76QzWZTp06d7OtMT09XRESEypUr5zSVRlZWliTZz1Z1Zfbs2SpUqJCaN2/usM7Y2FgFBQU5rfP8+fMO/dLT0696hvG+ffs0fvx4vfjiiwoKCnLafsWKFRUTE+Owzpzpai7fvis5NR0+fFgpKSmaO3euqlWrppCQkKsuW6BAAfXp08d+28fHR3369NGhQ4e0du1aSZKXl5d8fHwkXTwD/ejRo7pw4YJq1arl8PeW48SJE0pLS9PSpUs1f/583XvvvU59srKyrvi8uDJkyBDddddd6tixo0N7aGioJGnv3r15Ws+5c+ecnsPz58+77HvmzBmlp6fr2LFjDlNy5CYiIkITJ07U4sWL1bBhQ61fv17Tpk1TcHBwnmrLzMxUenq6Dh06pMWLF+u7775Ts2bNbrguSRo0aJC8vLzstx966CGFh4dr/vz5kqQ1a9bo0KFD6tu3r8NrMSEhQTExMfZ+l7p0GpicaWGysrK0ZMkSp76DBw/WzJkz9dlnn7n8RkfOt0Gki980OHHihBo2bOj0N3ajjzEAAEB+w9QlAADgtjV
"text/plain": [
"<Figure size 1800x500 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Функция для анализа сбалансированности\n",
"def analyze_balance(y_train, y_val, y_test, y_name):\n",
" # Распределение классов\n",
" print(\"Распределение классов в обучающей выборке:\")\n",
" print(y_train.value_counts(normalize=True))\n",
" \n",
" print(\"\\nРаспределение классов в контрольной выборке:\")\n",
" print(y_val.value_counts(normalize=True))\n",
" \n",
" print(\"\\nРаспределение классов в тестовой выборке:\")\n",
" print(y_test.value_counts(normalize=True))\n",
"\n",
" # Создание фигуры и осей для трех столбчатых диаграмм\n",
" fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)\n",
" fig.suptitle('Распределение в различных выборках')\n",
"\n",
" # Обучающая выборка\n",
" sns.barplot(x=y_train.value_counts().index, y=y_train.value_counts(normalize=True), ax=axes[0])\n",
" axes[0].set_title('Обучающая выборка')\n",
" axes[0].set_xlabel(y_name)\n",
" axes[0].set_ylabel('Доля')\n",
"\n",
" # Контрольная выборка\n",
" sns.barplot(x=y_val.value_counts().index, y=y_val.value_counts(normalize=True), ax=axes[1])\n",
" axes[1].set_title('Контрольная выборка')\n",
" axes[1].set_xlabel(y_name)\n",
"\n",
" # Тестовая выборка\n",
" sns.barplot(x=y_test.value_counts().index, y=y_test.value_counts(normalize=True), ax=axes[2])\n",
" axes[2].set_title('Тестовая выборка')\n",
" axes[2].set_xlabel(y_name)\n",
"\n",
" plt.show()\n",
"\n",
"analyze_balance(y_train, y_val, y_test, 'stroke')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Легко заметить, что выборки несбалансированны. Необходимо сбалансировать обучающую и контрольную выборки, чтобы получить лучшие результаты при обучении модели. Для балансировки применим RandomOverSampler:"
]
},
{
"cell_type": "code",
"execution_count": 446,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Распределение классов в обучающей выборке:\n",
"stroke\n",
"0 0.5\n",
"1 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в контрольной выборке:\n",
"stroke\n",
"0 0.5\n",
"1 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в тестовой выборке:\n",
"stroke\n",
"0 0.941944\n",
"1 0.058056\n",
"Name: proportion, dtype: float64\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABboAAAHyCAYAAAAtJXgGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABn9klEQVR4nO3dd3RU1d7G8ScJqYQipFJMaBKKFEMRMNJCAgQBvYKUKwEUkKaCFxVUAlgioggCSlGwgFcEFa+NKihiAAFpgkgJ0gOhhR7I7PcPVuZlmAkEiEwOfD9rZa3Mnn3O+Z0zmdlnnpzZ42GMMQIAAAAAAAAAwKI83V0AAAAAAAAAAAA3gqAbAAAAAAAAAGBpBN0AAAAAAAAAAEsj6AYAAAAAAAAAWBpBNwAAAAAAAADA0gi6AQAAAAAAAACWRtANAAAAAAAAALA0gm4AAAAAAAAAgKURdAMAAADAdbDZbEpPT9eOHTvcXQoAAMBtj6AbAAAAAHLpwIEDevrppxURESEfHx8FBwercuXKysjIcHdpAAAAt7UC7i4AAAAgr3344Yfq1q2b/bavr6/uvPNOxcXF6aWXXlJoaKgbqwNgVdu2bVPjxo11/vx5Pfnkk7rnnntUoEAB+fv7q2DBgu4uDwAA4LZG0A0AAG5ZI0aMUJkyZXT27Fn98ssveu+99/T9999r48aNCggIcHd5ACymV69e8vHx0fLly1WyZEl3lwMAAIBLEHQDAIBbVosWLVSrVi1J0uOPP67ixYtr9OjR+vrrr9WxY0c3VwfASlavXq0ff/xR8+fPJ+QGAADIh5ijGwAA3DaaNGkiSUpNTZUkHTlyRP/5z3909913KzAwUIULF1aLFi20bt06p2XPnj2rYcOG6a677pKfn5/Cw8P10EMPafv27ZKknTt3ysPDI8efRo0a2de1ZMkSeXh4aObMmRoyZIjCwsJUsGBBtW7dWrt373ba9ooVK9S8eXMVKVJEAQEBatiwoZYtW+ZyHxs1auRy+8OGDXPqO336dEVHR8vf31/FihVThw4dXG7/Svt2KZvNpjFjxqhKlSry8/NTaGioevXqpaNHjzr0i4yMVKtWrZy2069fP6d1uqp91KhRTsdUks6dO6ekpCSVL19evr6+Kl26tJ599lmdO3fO5bG61OXHLSgoSAkJCdq4cWOulq1atapWr16t+vXry9/fX2XKlNHEiRMd+mVmZmro0KGKjo5WkSJFVLBgQcXExGjx4sUO/bZs2aImTZooLCzMvh9PPPGEjhw54rTtrl27XvXx7tq1qyIjIx2W2717t/z9/eXh4aGdO3dK+v/H+cMPP3ToO2zYMJePS79+/ZzqadWqlcO2stf55ptv5nD0nNc/bdo0eXh4aOrUqQ79XnvtNXl4eOj777/PcV3Sxb+v7OPg6empsLAwPfLII9q1a9cN1bV8+XL5+flp+/btqlKlinx9fRUWFqZevXq5fGxmzZplf34FBQXp3//+t/bu3evQp2vXrgoMDNSOHTsUHx+vggULqkSJEhoxYoSMMU71XvrYnDhxQtHR0SpTpoz2799vb3/zzTdVv359FS9eXP7+/oqOjtbs2bMdtnujxxgAACA/4opuAABw28gOpYsXLy5J2rFjh+bMmaN27dqpTJkySktL06RJk9SwYUNt2rRJJUqUkCRlZWWpVatWWrRokTp06KCnnnpKJ06c0IIFC7Rx40aVK1fOvo2OHTuqZcuWDtsdPHiwy3peffVVeXh46LnnntPBgwc1ZswYxcbGau3atfL395ck/fjjj2rRooWio6OVlJQkT09PTZs2TU2aNNHSpUtVp04dp/WWKlVKycnJkqSTJ0+qd+/eLrf90ksvqX379nr88cd16NAhjRs3Tvfff79+//13FS1a1GmZnj17KiYmRpL05Zdf6quvvnK4v1evXvb50Z988kmlpqZq/Pjx+v3337Vs2TJ5e3u7PA7X4tixY/Z9u5TNZlPr1q31yy+/qGfPnqpUqZI2bNigt99+W3/99ZfmzJlz1XVHRUXphRdekDFG27dv1+jRo9WyZUuHgDQnR48eVcuWLdW+fXt17NhRn3/+uXr37i0fHx91795dkpSRkaH3339fHTt2VI8ePXTixAl98MEHio+P18qVK1WjRg1J0qlTp1SqVCk98MADKly4sDZu3KgJEyZo7969+uabb5y2HRQUpLffftt++9FHH71qvUOHDtXZs2ev2s8dunXrpi+//FIDBw5Us2bNVLp0aW3YsEHDhw/XY4895vT8ciUmJkY9e/aUzWbTxo0bNWbMGO3bt09Lly697roOHz6ss2fPqnfv3mrSpImeeOIJbd++XRMmTNCKFSu0YsUK+fr6Svr/7wmoXbu2kpOTlZaWprFjx2rZsmVOz6+srCw1b95c9957r9544w3NnTtXSUlJunDhgkaMGOGylvPnz+tf//qXdu3apWXLlik8PNx+39ixY9W6dWt17txZmZmZ+uyzz9SuXTt9++23SkhIyLNjDAAAkO8YAACAW8y0adOMJLNw4UJz6NAhs3v3bvPZZ5+Z4sWLG39/f7Nnzx5jjDFnz541WVlZDsumpqYaX19fM2LECHvb1KlTjSQzevRop23ZbDb7cpLMqFGjnPpUqVLFNGzY0H578eLFRpIpWbKkycjIsLd//vnnRpIZO3asfd0VKlQw8fHx9u0YY8zp06dNmTJlTLNmzZy2Vb9+fVO1alX77UOHDhlJJikpyd62c+dO4+XlZV599VWHZTds2GAKFCjg1L5161YjyXz00Uf2tqSkJHPpqeTSpUuNJDNjxgyHZefOnevUHhERYRISEpxq79u3r7n89PTy2p999lkTEhJioqOjHY7pJ598Yjw9Pc3SpUsdlp84caKRZJYtW+a0vUs1bNjQYX3GGDNkyBAjyRw8ePCqy0oyb731lr3t3LlzpkaNGiYkJMRkZmYaY4y5cOGCOXfunMOyR48eNaGhoaZ79+5X3EafPn1MYGCgU3vnzp1NmTJlHNouP2aJiYkmIiLCfnvjxo3G09PTtGjRwkgyqampxhhj/v77byPJTJ061WF9lz/W2dvo27evUz0JCQkO27rS8+JK69+/f78pVqyYadasmTl37pypWbOmufPOO83x48dzXE+2iIgIk5iY6NDWqVMnExAQcEN1Zd9u2rSpuXDhgr09+/Vm3LhxxhhjMjMzTUhIiKlatao5c+aMvd+3335rJJmhQ4fa2xITE40k079/f3ubzWYzCQkJxsfHxxw6dMih3mnTphmbzWY6d+5sAgICzIoVK5zqPn36tMPtzMxMU7VqVdOkSROH9hs5xgAAAPkRU5cAAIBbVmxsrIKDg1W6dGl16NBBgYGB+uqrr+zz6/r6+srT8+LpUFZWlg4fPqzAwEBVrFhRa9assa/niy++UFBQkPr37++0jcundLgWXbp0UaFChey3H374YYWHh9unDVi7dq22bt2qTp066fDhw0pPT1d6erpOnTqlpk2b6ueff5bNZnNY59mzZ+Xn53fF7X755Zey2Wxq3769fZ3p6ekKCwtThQoVnKbSyMzMlCT71aquzJo1S0WKFFGzZs0c1hkdHa3AwECndZ4/f96hX3p6+lWvMN67d6/GjRunl156SYGBgU7br1SpkqKiohzWmT1dzeXbdyW7pkOHDiklJUVfffWVqlWrpqCgoKsuW6BAAfXq1ct+28fHR7169dLBgwe1evVqSZKXl5d8fHwkXbwC/ciRI7pw4YJq1arl8PeW7fjx40pLS9OiRYv03Xff6f7773fqk5mZecXHxZXBgwfrnnvuUbt27Rzag4ODJUl79uzJ1XrOnj3r9BieP3/eZd/Tp08rPT1dR48edZiSIydhYWGaMGGCFixYoJiYGK1du1ZTp05V4cKFc1XbuXPnlJ6eroMHD2rBggX68ccf1bRp0xuuS5IGDhwoLy8v++1HH31UoaGh+u677yRJq1at0sGDB9WnTx+H52JCQoKioqLs/S516TQw2dPCZGZmauHChU59Bw0apBkzZujzzz93+YmO7E+DSBc/aXD8+HHFxMQ4/Y3d6DEGAADIb5i6BAAA3LI
"text/plain": [
"<Figure size 1800x500 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ros = RandomOverSampler(random_state=42)\n",
"\n",
"# Применение RandomOverSampler для балансировки выборок\n",
"X_train_resampled, y_train_resampled = ros.fit_resample(X_train, y_train)\n",
"X_val_resampled, y_val_resampled = ros.fit_resample(X_val, y_val)\n",
"\n",
"# Проверка сбалансированности после RandomOverSampler\n",
"analyze_balance(y_train_resampled, y_val_resampled, y_test, 'stroke')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выборки сбалансированы.\n",
"\n",
"### Перейдем к конструированию признаков\n",
"\n",
"Для начала применим унитарное кодирование категориальных признаков (one-hot encoding), переведя их в бинарные вектора:"
]
},
{
"cell_type": "code",
"execution_count": 447,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" id age hypertension heart_disease avg_glucose_level bmi \\\n",
"0 16605 57.0 0 0 106.24 32.3 \n",
"1 12015 14.0 0 0 99.87 25.2 \n",
"2 26474 44.0 0 0 97.16 33.1 \n",
"3 31143 22.0 0 0 107.52 41.6 \n",
"4 2447 63.0 0 0 85.04 29.7 \n",
"\n",
" gender_Male gender_Other ever_married_Yes work_type_Never_worked \\\n",
"0 True False True False \n",
"1 True False False False \n",
"2 False False True False \n",
"3 False False False False \n",
"4 False False True False \n",
"\n",
" work_type_Private work_type_Self-employed work_type_children \\\n",
"0 True False False \n",
"1 False False True \n",
"2 False False False \n",
"3 True False False \n",
"4 True False False \n",
"\n",
" Residence_type_Urban smoking_status_formerly smoked \\\n",
"0 True False \n",
"1 True False \n",
"2 True False \n",
"3 False False \n",
"4 True True \n",
"\n",
" smoking_status_never smoked smoking_status_smokes \n",
"0 True False \n",
"1 False False \n",
"2 False False \n",
"3 False False \n",
"4 False False \n"
]
}
],
"source": [
"# Определение категориальных признаков\n",
"categorical_features = ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']\n",
"\n",
"# Применение one-hot encoding к обучающей выборке\n",
"X_train_encoded = pd.get_dummies(X_train_resampled, columns=categorical_features, drop_first=True)\n",
"\n",
"# Применение one-hot encoding к контрольной выборке\n",
"X_val_encoded = pd.get_dummies(X_val_resampled, columns=categorical_features, drop_first=True)\n",
"\n",
"# Применение one-hot encoding к тестовой выборке\n",
"X_test_encoded = pd.get_dummies(X_test, columns=categorical_features, drop_first=True)\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Далее к числовым признакам, а именно к колонке age, применим дискретизацию (позволяет преобразовать данные из числового представления в категориальное):"
]
},
{
"cell_type": "code",
"execution_count": 448,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" id hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 16605 0 0 106.24 32.3 True \n",
"1 12015 0 0 99.87 25.2 True \n",
"2 26474 0 0 97.16 33.1 False \n",
"3 31143 0 0 107.52 41.6 False \n",
"4 2447 0 0 85.04 29.7 False \n",
"\n",
" gender_Other ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 False True False True \n",
"1 False False False False \n",
"2 False True False False \n",
"3 False False False True \n",
"4 False True False True \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 False False True \n",
"1 False True True \n",
"2 False False True \n",
"3 False False False \n",
"4 False False True \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 False True \n",
"1 False False \n",
"2 False False \n",
"3 False False \n",
"4 True False \n",
"\n",
" smoking_status_smokes age_bin \n",
"0 False old \n",
"1 False young \n",
"2 False middle-aged \n",
"3 False young \n",
"4 False old \n"
]
}
],
"source": [
"# Определение числовых признаков для дискретизации\n",
"numerical_features = ['age']\n",
"\n",
"# Функция для дискретизации числовых признаков\n",
"def discretize_features(df, features, bins, labels):\n",
" for feature in features:\n",
" df[f'{feature}_bin'] = pd.cut(df[feature], bins=bins, labels=labels)\n",
" df.drop(columns=[feature], inplace=True)\n",
" return df\n",
"\n",
"# Заданные интервалы и метки\n",
"age_bins = [0, 25, 55, 100]\n",
"age_labels = [\"young\", \"middle-aged\", \"old\"]\n",
"\n",
"# Применение дискретизации к обучающей, контрольной и тестовой выборкам\n",
"X_train_encoded = discretize_features(X_train_encoded, numerical_features, bins=age_bins, labels=age_labels)\n",
"X_val_encoded = discretize_features(X_val_encoded, numerical_features, bins=age_bins, labels=age_labels)\n",
"X_test_encoded = discretize_features(X_test_encoded, numerical_features, bins=age_bins, labels=age_labels)\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Применим ручной синтез признаков. Это создание новых признаков на основе существующих, учитывая экспертные знания и логику предметной области. К примеру, в этом случае можно создать признак, в котором вычисляется насколько уровень глюкозы отклоняется от среднего для возрастной группы пациента. Такой признак может быть полезен для выделения пациентов с нетипичными данными."
]
},
{
"cell_type": "code",
"execution_count": 449,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" id hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 16605 0 0 106.24 32.3 True \n",
"1 12015 0 0 99.87 25.2 True \n",
"2 26474 0 0 97.16 33.1 False \n",
"3 31143 0 0 107.52 41.6 False \n",
"4 2447 0 0 85.04 29.7 False \n",
"\n",
" gender_Other ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 False True False True \n",
"1 False False False False \n",
"2 False True False False \n",
"3 False False False True \n",
"4 False True False True \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 False False True \n",
"1 False True True \n",
"2 False False True \n",
"3 False False False \n",
"4 False False True \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 False True \n",
"1 False False \n",
"2 False False \n",
"3 False False \n",
"4 True False \n",
"\n",
" smoking_status_smokes age_bin glucose_age_deviation \n",
"0 False old -27.642870 \n",
"1 False young 6.088032 \n",
"2 False middle-aged -6.217053 \n",
"3 False young 13.738032 \n",
"4 False old -48.842870 \n"
]
}
],
"source": [
"age_glucose_mean = X_train_encoded.groupby('age_bin', observed=False)['avg_glucose_level'].transform('mean')\n",
"X_train_encoded['glucose_age_deviation'] = X_train_encoded['avg_glucose_level'] - age_glucose_mean\n",
"\n",
"age_glucose_mean = X_val_encoded.groupby('age_bin', observed=False)['avg_glucose_level'].transform('mean')\n",
"X_val_encoded['glucose_age_deviation'] = X_val_encoded['avg_glucose_level'] - age_glucose_mean\n",
"\n",
"age_glucose_mean = X_test_encoded.groupby('age_bin', observed=False)['avg_glucose_level'].transform('mean')\n",
"X_test_encoded['glucose_age_deviation'] = X_test_encoded['avg_glucose_level'] - age_glucose_mean\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Теперь используем масштабирование признаков, что позволяет привести все числовые признаки к одинаковым или очень похожим диапазонам значений либо распределениям. По результатам многочисленных исследований масштабирование признаков позволяет получить более качественную модель за счет снижения доминирования одних признаков над другими."
]
},
{
"cell_type": "code",
"execution_count": 450,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" id hypertension heart_disease avg_glucose_level bmi \\\n",
"0 16605 0 0 -0.244097 0.426328 \n",
"1 12015 0 0 -0.360110 -0.596170 \n",
"2 26474 0 0 -0.409465 0.541539 \n",
"3 31143 0 0 -0.220785 1.765656 \n",
"4 2447 0 0 -0.630199 0.051892 \n",
"\n",
" gender_Male gender_Other ever_married_Yes work_type_Never_worked \\\n",
"0 True False True False \n",
"1 True False False False \n",
"2 False False True False \n",
"3 False False False False \n",
"4 False False True False \n",
"\n",
" work_type_Private work_type_Self-employed work_type_children \\\n",
"0 True False False \n",
"1 False False True \n",
"2 False False False \n",
"3 True False False \n",
"4 True False False \n",
"\n",
" Residence_type_Urban smoking_status_formerly smoked \\\n",
"0 True False \n",
"1 True False \n",
"2 True False \n",
"3 False False \n",
"4 True True \n",
"\n",
" smoking_status_never smoked smoking_status_smokes age_bin \\\n",
"0 True False old \n",
"1 False False young \n",
"2 False False middle-aged \n",
"3 False False young \n",
"4 False False old \n",
"\n",
" glucose_age_deviation \n",
"0 -0.528807 \n",
"1 0.116464 \n",
"2 -0.118932 \n",
"3 0.262808 \n",
"4 -0.934362 \n"
]
}
],
"source": [
"# Пример масштабирования числовых признаков\n",
"numerical_features = ['avg_glucose_level', 'bmi', 'glucose_age_deviation']\n",
"\n",
"scaler = StandardScaler()\n",
"X_train_encoded[numerical_features] = scaler.fit_transform(X_train_encoded[numerical_features])\n",
"X_val_encoded[numerical_features] = scaler.transform(X_val_encoded[numerical_features])\n",
"X_test_encoded[numerical_features] = scaler.transform(X_test_encoded[numerical_features])\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"И также попробуем сконструировать признаки, используя фреймворк Featuretools:"
]
},
{
"cell_type": "code",
"execution_count": 451,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" id hypertension heart_disease avg_glucose_level bmi \\\n",
"index \n",
"0 16605 0 0 -0.244097 0.426328 \n",
"1 12015 0 0 -0.360110 -0.596170 \n",
"2 26474 0 0 -0.409465 0.541539 \n",
"3 31143 0 0 -0.220785 1.765656 \n",
"4 2447 0 0 -0.630199 0.051892 \n",
"\n",
" gender_Male gender_Other ever_married_Yes work_type_Never_worked \\\n",
"index \n",
"0 True False True False \n",
"1 True False False False \n",
"2 False False True False \n",
"3 False False False False \n",
"4 False False True False \n",
"\n",
" work_type_Private work_type_Self-employed work_type_children \\\n",
"index \n",
"0 True False False \n",
"1 False False True \n",
"2 False False False \n",
"3 True False False \n",
"4 True False False \n",
"\n",
" Residence_type_Urban smoking_status_formerly smoked \\\n",
"index \n",
"0 True False \n",
"1 True False \n",
"2 True False \n",
"3 False False \n",
"4 True True \n",
"\n",
" smoking_status_never smoked smoking_status_smokes age_bin \\\n",
"index \n",
"0 True False old \n",
"1 False False young \n",
"2 False False middle-aged \n",
"3 False False young \n",
"4 False False old \n",
"\n",
" glucose_age_deviation \n",
"index \n",
"0 -0.528807 \n",
"1 0.116464 \n",
"2 -0.118932 \n",
"3 0.262808 \n",
"4 -0.934362 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Ilya\\Desktop\\AIM\\aimenv\\Lib\\site-packages\\woodwork\\type_sys\\utils.py:33: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
" pd.to_datetime(\n"
]
}
],
"source": [
"data = X_train_encoded.copy() # Используем предобработанные данные\n",
"\n",
"es = ft.EntitySet(id=\"patients\")\n",
"\n",
"es = es.add_dataframe(dataframe_name=\"strokes_data\", dataframe=data, index=\"index\", make_index=True)\n",
"\n",
"feature_matrix, feature_defs = ft.dfs(\n",
" entityset=es, \n",
" target_dataframe_name=\"strokes_data\",\n",
" max_depth=1\n",
")\n",
"\n",
"print(feature_matrix.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Оценим качество набора признаков.\n",
"\n",
"Представим основные оценки качества наборов признаков: \n",
"\n",
"* Предсказательная способность Метрики: RMSE, MAE, R²\n",
"\n",
" Методы: Обучение модели на обучающей выборке и оценка на контрольной и тестовой выборках.\n",
"\n",
"* Скорость вычисления \n",
"\n",
" Методы: Измерение времени выполнения генерации признаков и обучения модели.\n",
"\n",
"* Надежность \n",
"\n",
" Методы: Кросс-валидация, анализ чувствительности модели к изменениям в данных.\n",
"\n",
"* Корреляция \n",
"\n",
" Методы: Анализ корреляционной матрицы признаков, удаление мультиколлинеарных признаков.\n",
"\n",
"* Цельность \n",
"\n",
" Методы: Проверка логической связи между признаками и целевой переменной, интерпретация результатов модели."
]
},
{
"cell_type": "code",
"execution_count": 452,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Время обучения модели: 0.01 секунд\n",
"Среднеквадратичная ошибка: 0.41\n"
]
}
],
"source": [
"X_train_encoded = pd.get_dummies(X_train_encoded, drop_first=True)\n",
"X_val_encoded = pd.get_dummies(X_val_encoded, drop_first=True)\n",
"X_test_encoded = pd.get_dummies(X_test_encoded, drop_first=True)\n",
"\n",
"all_columns = X_train_encoded.columns\n",
"X_train_encoded = X_train_encoded.reindex(columns=all_columns, fill_value=0)\n",
"X_val_encoded = X_val_encoded.reindex(columns=all_columns, fill_value=0)\n",
"X_test_encoded = X_test_encoded.reindex(columns=all_columns, fill_value=0)\n",
"\n",
"# Обучение модели\n",
"model = LinearRegression()\n",
"\n",
"# Начинаем отсчет времени\n",
"start_time = time.time()\n",
"model.fit(X_train_encoded, y_train_resampled)\n",
"\n",
"# Время обучения модели\n",
"train_time = time.time() - start_time\n",
"\n",
"# Предсказания и оценка модели и вычисляем среднеквадратичную ошибку\n",
"predictions = model.predict(X_val_encoded)\n",
"mse = root_mean_squared_error(y_val_resampled, predictions)\n",
"\n",
"print(f'Время обучения модели: {train_time:.2f} секунд')\n",
"print(f'Среднеквадратичная ошибка: {mse:.2f}')"
]
},
{
"cell_type": "code",
"execution_count": 453,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"RMSE: 0.24109840514907446\n",
"R²: -0.06295721700021817\n",
"MAE: 0.10402478799739073 \n",
"\n",
"Кросс-валидация RMSE: 0.1197518340742331 \n",
"\n",
"Train RMSE: 0.037396456827854585\n",
"Train R²: 0.9944060200668896\n",
"Train MAE: 0.010727424749163881\n",
"\n"
]
}
],
"source": [
"# Выбор модели\n",
"model = RandomForestRegressor(random_state=42)\n",
"\n",
"# Обучение модели\n",
"model.fit(X_train_encoded, y_train_resampled)\n",
"\n",
"# Предсказание и оценка\n",
"y_pred = model.predict(X_test_encoded)\n",
"\n",
"rmse = root_mean_squared_error(y_test, y_pred)\n",
"r2 = r2_score(y_test, y_pred)\n",
"mae = mean_absolute_error(y_test, y_pred)\n",
"\n",
"print()\n",
"print(f\"RMSE: {rmse}\")\n",
"print(f\"R²: {r2}\")\n",
"print(f\"MAE: {mae} \\n\")\n",
"\n",
"# Кросс-валидация\n",
"scores = cross_val_score(model, X_train_encoded, y_train_resampled, cv=5, scoring='neg_mean_squared_error')\n",
"rmse_cv = (-scores.mean())**0.5\n",
"print(f\"Кросс-валидация RMSE: {rmse_cv} \\n\")\n",
"\n",
"# Проверка на переобучение\n",
"y_train_pred = model.predict(X_train_encoded)\n",
"\n",
"rmse_train = root_mean_squared_error(y_train_resampled, y_train_pred)\n",
"r2_train = r2_score(y_train_resampled, y_train_pred)\n",
"mae_train = mean_absolute_error(y_train_resampled, y_train_pred)\n",
"\n",
"print(f\"Train RMSE: {rmse_train}\")\n",
"print(f\"Train R²: {r2_train}\")\n",
"print(f\"Train MAE: {mae_train}\")\n",
"print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Можно заметить, что модель хорошо подстроилась под тренировочные данные (Низкий Train RMSE и высокое значение Train R²). Однако высокий RMSE и отрицательный R² на тестовом наборе свидетельствуют о том, что модель не обобщила зависимости и плохо предсказывает новые данные, поэтому можно сделать вывод о том, что получившийся набор признаков, к сожалению, далек от идеала. "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}