AIM-PIbd-31-Rodionov-I-A/lab_3/lab3.ipynb

1111 lines
192 KiB
Plaintext
Raw Permalink Normal View History

2024-11-01 23:33:34 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Данные по инсультам\n",
"\n",
"Выведем информацию о столбцах датасета:"
]
},
{
"cell_type": "code",
"execution_count": 136,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['id', 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married',\n",
" 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi',\n",
" 'smoking_status', 'stroke'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9046</td>\n",
" <td>Male</td>\n",
" <td>67.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>228.69</td>\n",
" <td>36.6</td>\n",
" <td>formerly smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>51676</td>\n",
" <td>Female</td>\n",
" <td>61.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>202.21</td>\n",
" <td>NaN</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>31112</td>\n",
" <td>Male</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>105.92</td>\n",
" <td>32.5</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>60182</td>\n",
" <td>Female</td>\n",
" <td>49.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>171.23</td>\n",
" <td>34.4</td>\n",
" <td>smokes</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1665</td>\n",
" <td>Female</td>\n",
" <td>79.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>174.12</td>\n",
" <td>24.0</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id gender age hypertension heart_disease ever_married \\\n",
"0 9046 Male 67.0 0 1 Yes \n",
"1 51676 Female 61.0 0 0 Yes \n",
"2 31112 Male 80.0 0 1 Yes \n",
"3 60182 Female 49.0 0 0 Yes \n",
"4 1665 Female 79.0 1 0 Yes \n",
"\n",
" work_type Residence_type avg_glucose_level bmi smoking_status \\\n",
"0 Private Urban 228.69 36.6 formerly smoked \n",
"1 Self-employed Rural 202.21 NaN never smoked \n",
"2 Private Rural 105.92 32.5 never smoked \n",
"3 Private Urban 171.23 34.4 smokes \n",
"4 Self-employed Rural 174.12 24.0 never smoked \n",
"\n",
" stroke \n",
"0 1 \n",
"1 1 \n",
"2 1 \n",
"3 1 \n",
"4 1 "
]
},
"execution_count": 136,
2024-11-01 23:33:34 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from sklearn.preprocessing import StandardScaler\n",
"import featuretools as ft\n",
"import time\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import cross_val_score\n",
2024-11-01 23:33:34 +04:00
"\n",
"df = pd.read_csv(\"..//..//static//csv//healthcare-dataset-stroke-data.csv\")\n",
"\n",
"print(df.columns)\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Определим бизнес цели и цели технического проекта.\n",
"\n",
"1. Улучшение диагностики и профилактики инсульта.\n",
" * Бизнес-цель: повышение точности прогнозирования риска инсульта среди пациентов для более раннего лечебного вмешательства. Определение основных факторов риска для более целенаправленного подхода в медицинском обслуживании.\n",
" * Цель технического проекта: разработка статистической модели, которая решает задачу классификации и предсказывает возможность возникновения инсульта у пациентов на основе имеющихся данных (возраст, гипертония, заболевания сердца и пр.), с целью выявления групп риска. Внедрение этой модели в систему поддержки принятия медицинских решений для врачей.\n",
"2. Снижение расходов на лечение инсультов.\n",
" * Бизнес-цель: снижение затрат на лечение инсульта путем более эффективного распределения медицинских ресурсов и направленных профилактических мер.\n",
" * Цель технического проекта: создание системы оценки индивидуального риска инсульта для пациентов, что позволит медучреждениям проводить профилактические меры среди целевых групп, сокращая расходы на лечение.\n",
"\n",
"### И теперь проверим датасет на пустые значения:"
]
},
{
"cell_type": "code",
"execution_count": 137,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"id 0\n",
"gender 0\n",
"age 0\n",
"hypertension 0\n",
"heart_disease 0\n",
"ever_married 0\n",
"work_type 0\n",
"Residence_type 0\n",
"avg_glucose_level 0\n",
"bmi 201\n",
"smoking_status 0\n",
"stroke 0\n",
"dtype: int64\n",
"\n",
"id False\n",
"gender False\n",
"age False\n",
"hypertension False\n",
"heart_disease False\n",
"ever_married False\n",
"work_type False\n",
"Residence_type False\n",
"avg_glucose_level False\n",
"bmi True\n",
"smoking_status False\n",
"stroke False\n",
"dtype: bool\n",
"\n",
"bmi процент пустых значений: %3.93\n"
]
}
],
"source": [
"# Количество пустых значений признаков\n",
"print(df.isnull().sum())\n",
"\n",
"print()\n",
"\n",
"# Есть ли пустые значения признаков\n",
"print(df.isnull().any())\n",
"\n",
"print()\n",
"\n",
"# Процент пустых значений признаков\n",
"for i in df.columns:\n",
" null_rate = df[i].isnull().sum() / len(df) * 100\n",
" if null_rate > 0:\n",
" print(f\"{i} процент пустых значений: %{null_rate:.2f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В столбце bmi можно заметить пустые значение. Заменим их на медиану:"
]
},
{
"cell_type": "code",
"execution_count": 138,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Количество пустых значений в каждом столбце после замены:\n",
"id 0\n",
"gender 0\n",
"age 0\n",
"hypertension 0\n",
"heart_disease 0\n",
"ever_married 0\n",
"work_type 0\n",
"Residence_type 0\n",
"avg_glucose_level 0\n",
"bmi 0\n",
"smoking_status 0\n",
"stroke 0\n",
"dtype: int64\n"
]
}
],
"source": [
"# Замена значений\n",
"df[\"bmi\"] = df[\"bmi\"].fillna(df[\"bmi\"].median())\n",
"\n",
"# Проверка на пропущенные значения после замены\n",
"missing_values_after_drop = df.isnull().sum()\n",
"\n",
"# Вывод результатов после замены\n",
"print(\"\\nКоличество пустых значений в каждом столбце после замены:\")\n",
"print(missing_values_after_drop)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Удалим из датафрейма столбец id, потому что нет смысла учитывать его при предсказании:"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',\n",
" 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi',\n",
" 'smoking_status', 'stroke'],\n",
" dtype='object')\n"
]
}
],
"source": [
"df = df.drop('id', axis=1)\n",
"print(df.columns)"
]
},
2024-11-01 23:33:34 +04:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Можно перейти к созданию выборок"
]
},
{
"cell_type": "code",
"execution_count": 140,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Размер обучающей выборки: (2503, 10)\n",
"Размер контрольной выборки: (1074, 10)\n",
"Размер тестовой выборки: (1533, 10)\n"
2024-11-01 23:33:34 +04:00
]
}
],
"source": [
"# Разделение данных на признаки (X) и целевую переменную (y)\n",
"# В данном случае мы хотим предсказать 'stroke'\n",
"X = df.drop(columns=['stroke'])\n",
"y = df['stroke']\n",
"\n",
"# Разбиение данных на обучающую и тестовую выборки\n",
"# Сначала разделим на обучающую и тестовую\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n",
"\n",
"# Затем разделим обучающую выборку на обучающую и контрольную\n",
"X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.3)\n",
"\n",
"# Проверка размеров выборок\n",
"print(\"Размер обучающей выборки:\", X_train.shape)\n",
"print(\"Размер контрольной выборки:\", X_val.shape)\n",
"print(\"Размер тестовой выборки:\", X_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Оценим сбалансированность выборок:"
]
},
{
"cell_type": "code",
"execution_count": 141,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Распределение классов в обучающей выборке:\n",
"stroke\n",
"0 0.95006\n",
"1 0.04994\n",
2024-11-01 23:33:34 +04:00
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в контрольной выборке:\n",
"stroke\n",
"0 0.951583\n",
"1 0.048417\n",
2024-11-01 23:33:34 +04:00
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в тестовой выборке:\n",
"stroke\n",
"0 0.953033\n",
"1 0.046967\n",
2024-11-01 23:33:34 +04:00
"Name: proportion, dtype: float64\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABboAAAHyCAYAAAAtJXgGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABpeklEQVR4nO3deZyNdf/H8ffMmH0sMattxj52Gktosg2DEeqOLHeGCoUK3Soqg5ZJSoSyFC10J4ruImuUNBFSRBIjSwxjG+sMc76/Pzzm/BznzBhLzlx6PR+PeTzmfM/3uq7Pdc7M+V7nfa7zvTyMMUYAAAAAAAAAAFiUp7sLAAAAAAAAAADgehB0AwAAAAAAAAAsjaAbAAAAAAAAAGBpBN0AAAAAAAAAAEsj6AYAAAAAAAAAWBpBNwAAAAAAAADA0gi6AQAAAAAAAACWRtANAAAAAAAAALA0gm4AAAAAuAY2m03p6enatWuXu0sBAAD4xyPoBgAAAIB8OnjwoAYNGqTIyEj5+PgoJCRE1apVU0ZGhrtLAwAA+Ecr5O4CAAAAbrT33ntPvXv3tt/29fVV2bJl1bp1az3//PMKCwtzY3UArOqPP/5Q8+bNdf78eT3++OO6/fbbVahQIfn7+yswMNDd5QEAAPyjEXQDAIBb1ujRo1WuXDmdO3dO3333nd5++20tWrRIW7ZsUUBAgLvLA2Ax/fr1k4+Pj3744QeVKlXK3eUAAADgEgTdAADgltW2bVvVq1dPkvTwww+rRIkSGjdunD7//HN169bNzdUBsJINGzbo66+/1tKlSwm5AQAACiDm6AYAAP8YLVq0kCSlpqZKko4ePar//Oc/qlmzpoKCglSkSBG1bdtWP//8s9Oy586d08iRI1W5cmX5+fkpIiJC9957r3bu3ClJ2r17tzw8PHL9adasmX1dq1atkoeHh+bMmaPhw4crPDxcgYGB6tChg/bu3eu07bVr16pNmzYqWrSoAgIC1LRpU61Zs8blPjZr1szl9keOHOnUd9asWYqJiZG/v7+KFy+url27utx+Xvt2KZvNpvHjx6t69ery8/NTWFiY+vXrp2PHjjn0i4qKUvv27Z22M3DgQKd1uqp97NixTo+pJGVmZiopKUkVK1aUr6+vypQpo6eeekqZmZkuH6tLXf64BQcHKyEhQVu2bMnXsjVq1NCGDRvUuHFj+fv7q1y5cpoyZYpDv6ysLI0YMUIxMTEqWrSoAgMDFRsbq5UrVzr02759u1q0aKHw8HD7fjzyyCM6evSo07Z79ep1xee7V69eioqKclhu79698vf3l4eHh3bv3i3p/5/n9957z6HvyJEjXT4vAwcOdKqnffv2DtvKWedrr72Wy6PnvP6ZM2fKw8NDM2bMcOj38ssvy8PDQ4sWLcp1XdLFv6+cx8HT01Ph4eG6//77tWfPnuuq64cffpCfn5927typ6tWry9fXV+Hh4erXr5/L52bu3Ln2/6/g4GD9+9//1v79+x369OrVS0FBQdq1a5fi4+MVGBiokiVLavTo0TLGONV76XNz8uRJxcTEqFy5cjpw4IC9/bXXXlPjxo1VokQJ+fv7KyYmRvPmzXPY7vU+xgAAAAURZ3QDAIB/jJxQukSJEpKkXbt2acGCBercubPKlSuntLQ0TZ06VU2bNtXWrVtVsmRJSVJ2drbat2+vFStWqGvXrnriiSd08uRJLVu2TFu2bFGFChXs2+jWrZvatWvnsN1hw4a5rOell16Sh4eHnn76aR06dEjjx49XXFycNm3aJH9/f0nS119/rbZt2yomJkZJSUny9PTUzJkz1aJFC61evVoNGjRwWm/p0qWVnJwsSTp16pQeffRRl9t+/vnn1aVLFz388MM6fPiwJk6cqLvuuks//fSTihUr5rRM3759FRsbK0n67LPPNH/+fIf7+/XrZ58f/fHHH1dqaqomTZqkn376SWvWrJG3t7fLx+FqHD9+3L5vl7LZbOrQoYO+++479e3bV1WrVtXmzZv1xhtv6Pfff9eCBQuuuO7o6Gg9++yzMsZo586dGjdunNq1a+cQkObm2LFjateunbp06aJu3brpk08+0aOPPiofHx89+OCDkqSMjAy988476tatm/r06aOTJ0/q3XffVXx8vNatW6c6depIkk6fPq3SpUvr7rvvVpEiRbRlyxZNnjxZ+/fv1xdffOG07eDgYL3xxhv22w888MAV6x0xYoTOnTt3xX7u0Lt3b3322WcaMmSIWrVqpTJlymjz5s0aNWqUHnroIaf/L1diY2PVt29f2Ww2bdmyRePHj9dff/2l1atXX3NdR44c0blz5/Too4+qRYsWeuSRR7Rz505NnjxZa9eu1dq1a+Xr6yvp/68TUL9+fSUnJystLU0TJkzQmjVrnP6/srOz1aZNG91xxx169dVXtXjxYiUlJenChQsaPXq0y1rOnz+vf/3rX9qzZ4/WrFmjiIgI+30TJkxQhw4d1KNHD2VlZenjjz9W586d9eWXXyohIeGGPcYAAAAFjgEAALjFzJw500gyy5cvN4cPHzZ79+41H3/8sSlRooTx9/c3+/btM8YYc+7cOZOdne2wbGpqqvH19TWjR4+2t82YMcNIMuPGjXPals1msy8nyYwdO9apT/Xq1U3Tpk3tt1euXGkkmVKlSpmMjAx7+yeffGIkmQkTJtjXXalSJRMfH2/fjjHGnDlzxpQrV860atXKaVuNGzc2NWrUsN8+fPiwkWSSkpLsbbt37zZeXl7mpZdeclh28+bNplChQk7tO3bsMJLM+++/b29LSkoylx5Krl692kgys2fPdlh28eLFTu2RkZEmISHBqfYBAwaYyw9PL6/9qaeeMqGhoSYmJsbhMf3www+Np6enWb16tcPyU6ZMMZLMmjVrnLZ3qaZNmzqszxhjhg8fbiSZQ4cOXXFZSeb111+3t2VmZpo6deqY0NBQk5WVZYwx5sKFCyYzM9Nh2WPHjpmwsDDz4IMP5rmN/v37m6CgIKf2Hj16mHLlyjm0Xf6YJSYmmsjISPvtLVu2GE9PT9O2bVsjyaSmphpjjPnzzz+NJDNjxgyH9V3+XOdsY8CAAU71JCQkOGwrr/+LvNZ/4MABU7x4cdOqVSuTmZlp6tata8qWLWtOnDiR63pyREZGmsTERIe27t27m4CAgOuqK+d2y5YtzYULF+ztOa83EydONMYYk5WVZUJDQ02NGjXM2bNn7f2+/PJLI8mMGDHC3paYmGgkmccee8zeZrPZTEJCgvHx8TGHDx92qHfmzJnGZrOZHj16mICAALN27Vqnus+cOeNwOysry9SoUcO0aNHCof16HmMAAICCiKlLAADALSsuLk4hISEqU6aMunbtqqCgIM2fP98+v66vr688PS8eDmVnZ+vIkSMKCgpSlSpVtHHjRvt6Pv30UwUHB+uxxx5z2sblUzpcjZ49e6pw4cL22/fdd58iIiLs0wZs2rRJO3bsUPfu3XXkyBGlp6crPT1dp0+fVsuWLfXtt9/KZrM5rPPcuXPy8/PLc7ufffaZbDabunTpYl9nenq6wsPDValSJaepNLKysiTJfraqK3PnzlXRokXVqlUrh3XGxMQoKCjIaZ3nz5936Jeenn7FM4z379+viRMn6vnnn1dQUJDT9qtWraro6GiHdeZMV3P59l3Jqenw4cNKSUnR/PnzVatWLQUHB19x2UKFCqlfv3722z4+PurXr58OHTqkDRs2SJK8vLzk4+Mj6eIZ6EePHtWFCxdUr149h7+3HCdOnFBaWppWrFihhQsX6q677nLqk5WVlefz4sqwYcN0++23q3Pnzg7tISEhkqR9+/blaz3nzp1zeg7Pnz/vsu+ZM2eUnp6uY8eOOUzJkZvw8HBNnjxZy5YtU2xsrDZt2qQZM2aoSJEi+aotMzNT6enpOnTokJYtW6avv/5aLVu2vO66JGnIkCHy8vKy337ggQcUFhamhQsXSpLWr1+vQ4cOqX///g7/iwkJCYqOjrb3u9Sl08DkTAuTlZWl5cuXO/UdOnSoZs+erU8++cTlNzpyvg0iXfymwYkTJxQbG+v0N3a9jzEAAEBBw9QlAADgljV
2024-11-01 23:33:34 +04:00
"text/plain": [
"<Figure size 1800x500 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Функция для анализа сбалансированности\n",
"def analyze_balance(y_train, y_val, y_test, y_name):\n",
" # Распределение классов\n",
" print(\"Распределение классов в обучающей выборке:\")\n",
" print(y_train.value_counts(normalize=True))\n",
" \n",
" print(\"\\nРаспределение классов в контрольной выборке:\")\n",
" print(y_val.value_counts(normalize=True))\n",
" \n",
" print(\"\\nРаспределение классов в тестовой выборке:\")\n",
" print(y_test.value_counts(normalize=True))\n",
"\n",
" # Создание фигуры и осей для трех столбчатых диаграмм\n",
" fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)\n",
" fig.suptitle('Распределение в различных выборках')\n",
"\n",
" # Обучающая выборка\n",
" sns.barplot(x=y_train.value_counts().index, y=y_train.value_counts(normalize=True), ax=axes[0])\n",
" axes[0].set_title('Обучающая выборка')\n",
" axes[0].set_xlabel(y_name)\n",
" axes[0].set_ylabel('Доля')\n",
"\n",
" # Контрольная выборка\n",
" sns.barplot(x=y_val.value_counts().index, y=y_val.value_counts(normalize=True), ax=axes[1])\n",
" axes[1].set_title('Контрольная выборка')\n",
" axes[1].set_xlabel(y_name)\n",
"\n",
" # Тестовая выборка\n",
" sns.barplot(x=y_test.value_counts().index, y=y_test.value_counts(normalize=True), ax=axes[2])\n",
" axes[2].set_title('Тестовая выборка')\n",
" axes[2].set_xlabel(y_name)\n",
"\n",
" plt.show()\n",
"\n",
"analyze_balance(y_train, y_val, y_test, 'stroke')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Легко заметить, что выборки несбалансированны. Необходимо сбалансировать обучающую и контрольную выборки, чтобы получить лучшие результаты при обучении модели. Для балансировки применим RandomOverSampler:"
]
},
{
"cell_type": "code",
"execution_count": 142,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Распределение классов в обучающей выборке:\n",
"stroke\n",
"0 0.5\n",
"1 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в контрольной выборке:\n",
"stroke\n",
"0 0.5\n",
"1 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в тестовой выборке:\n",
"stroke\n",
"0 0.953033\n",
"1 0.046967\n",
2024-11-01 23:33:34 +04:00
"Name: proportion, dtype: float64\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABboAAAHyCAYAAAAtJXgGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABpmElEQVR4nO3deZyNdf/H8ffMmH0MMattxj52Gktosg2DEeqOLHeGCoUK3Soqg5ZJSoSyFC10J4ruImuUNBFSRBIjSwxjG+sMc76/Pzzm/BznDGPJmUuv5+Mxj8ec7/le1/W5zplzvtd5z3W+l4cxxggAAAAAAAAAAIvydHcBAAAAAAAAAABcD4JuAAAAAAAAAIClEXQDAAAAAAAAACyNoBsAAAAAAAAAYGkE3QAAAAAAAAAASyPoBgAAAAAAAABYGkE3AAAAAAAAAMDSCLoBAAAAAAAAAJZG0A0AAAAA18BmsykjI0M7d+50dykAAAD/eATdAAAAAJBPBw4c0MCBAxUVFSUfHx+FhoaqatWqyszMdHdpAAAA/2iF3F0AAADAjfbee++pV69e9tu+vr4qU6aMWrVqpeeff17h4eFurA6AVf3xxx9q1qyZzp07p8cff1y33367ChUqJH9/fwUGBrq7PAAAgH80gm4AAHDLGjVqlMqWLauzZ8/qu+++09tvv62FCxdq8+bNCggIcHd5ACymb9++8vHx0Q8//KCSJUu6uxwAAABchKAbAADcstq0aaO6detKkh5++GEVL15cY8eO1eeff66uXbu6uToAVrJ+/Xp9/fXXWrJkCSE3AABAAcQc3QAA4B+jefPmkqS0tDRJ0pEjR/Sf//xHNWrUUFBQkIKDg9WmTRv9/PPPTsuePXtWI0aMUKVKleTn56fIyEjde++92rFjhyRp165d8vDwyPOnadOm9nWtXLlSHh4emj17toYNG6aIiAgFBgaqffv22rNnj9O216xZo9atW6tIkSIKCAhQkyZNtHr1apf72LRpU5fbHzFihFPfmTNnKjY2Vv7+/ipWrJi6dOnicvuX27eL2Ww2jRs3TtWqVZOfn5/Cw8PVt29fHT161KFfdHS02rVr57SdAQMGOK3TVe1jxoxxekwlKSsrS8nJyapQoYJ8fX1VunRpPfXUU8rKynL5WF3s0sctJCREiYmJ2rx5c76WrV69utavX69GjRrJ399fZcuW1eTJkx36ZWdna/jw4YqNjVWRIkUUGBiouLg4rVixwqHftm3b1Lx5c0VERNj345FHHtGRI0ectt2zZ88rPt89e/ZUdHS0w3J79uyRv7+/PDw8tGvXLkn//zy/9957Dn1HjBjh8nkZMGCAUz3t2rVz2FbuOl977bU8Hj3n9c+YMUMeHh6aPn26Q7+XX35ZHh4eWrhwYZ7rki78feU+Dp6enoqIiND999+v3bt3X1ddP/zwg/z8/LRjxw5Vq1ZNvr6+ioiIUN++fV0+N3PmzLG/vkJCQvTvf/9b+/btc+jTs2dPBQUFaefOnUpISFBgYKBKlCihUaNGyRjjVO/Fz82JEycUGxursmXLav/+/fb21157TY0aNVLx4sXl7++v2NhYzZ0712G71/sYAwAAFESc0Q0AAP4xckPp4sWLS5J27typ+fPnq1OnTipbtqzS09M1ZcoUNWnSRFu2bFGJEiUkSTk5OWrXrp2WL1+uLl266IknntCJEye0dOlSbd68WeXLl7dvo2vXrmrbtq3DdocOHeqynpdeekkeHh56+umndfDgQY0bN07x8fHauHGj/P39JUlff/212rRpo9jYWCUnJ8vT01MzZsxQ8+bNtWrVKtWvX99pvaVKlVJKSook6eTJk3r00Uddbvv5559X586d9fDDD+vQoUOaMGGC7rrrLv30008qWrSo0zJ9+vRRXFycJOmzzz7TvHnzHO7v27evfX70xx9/XGlpaZo4caJ++uknrV69Wt7e3i4fh6tx7Ngx+75dzGazqX379vruu+/Up08fValSRZs2bdIbb7yh33//XfPnz7/iumNiYvTss8/KGKMdO3Zo7Nixatu2rUNAmpejR4+qbdu26ty5s7p27apPPvlEjz76qHx8fPTggw9KkjIzM/XOO++oa9eu6t27t06cOKF3331XCQkJWrt2rWrXri1JOnXqlEqVKqW7775bwcHB2rx5syZNmqR9+/bpiy++cNp2SEiI3njjDfvtBx544Ir1Dh8+XGfPnr1iP3fo1auXPvvsMw0ePFgtW7ZU6dKltWnTJo0cOVIPPfSQ0+vLlbi4OPXp00c2m02bN2/WuHHj9Ndff2nVqlXXXNfhw4d19uxZPfroo2revLkeeeQR7dixQ5MmTdKaNWu0Zs0a+fr6Svr/6wTUq1dPKSkpSk9P1/jx47V69Wqn11dOTo5at26tO+64Q6+++qoWLVqk5ORknT9/XqNGjXJZy7lz5/Svf/1Lu3fv1urVqxUZGWm/b/z48Wrfvr26d++u7Oxsffzxx+rUqZO+/PJLJSYm3rDHGAAAoMAxAAAAt5gZM2YYSWbZsmXm0KFDZs+ePebjjz82xYsXN/7+/mbv3r3GGGPOnj1rcnJyHJZNS0szvr6+ZtSoUfa26dOnG0lm7NixTtuy2Wz25SSZMWPGOPWpVq2aadKkif32ihUrjCRTsmRJk5mZaW//5JNPjCQzfvx4+7orVqxoEhIS7NsxxpjTp0+bsmXLmpYtWzptq1GjRqZ69er224cOHTKSTHJysr1t165dxsvLy7z00ksOy27atMkUKlTIqX379u1Gknn//fftbcnJyebiQ8lVq1YZSWbWrFkOyy5atMipPSoqyiQmJjrV3r9/f3Pp4emltT/11FMmLCzMxMbGOjymH374ofH09DSrVq1yWH7y5MlGklm9erXT9i7WpEkTh/UZY8ywYcOMJHPw4MErLivJvP766/a2rKwsU7t2bRMWFmays7ONMcacP3/eZGVlOSx79OhREx4ebh588MHLbqNfv34mKCjIqb179+6mbNmyDm2XPmZJSUkmKirKfnvz5s3G09PTtGnTxkgyaWlpxhhj/vzzTyPJTJ8+3WF9lz7Xudvo37+/Uz2JiYkO27rc6+Jy69+/f78pVqyYadmypcnKyjJ16tQxZcqUMcePH89zPbmioqJMUlKSQ1u3bt1MQEDAddWVe7tFixbm/Pnz9vbc95sJEyYYY4zJzs42YWFhpnr16ubMmTP2fl9++aWRZIYPH25vS0pKMpLMY489Zm+z2WwmMTHR+Pj4mEOHDjnUO2PGDGOz2Uz37t1NQECAWbNmjVPdp0+fdridnZ1tqlevbpo3b+7Qfj2PMQAAQEHE1CUAAOCWFR8fr9DQUJUuXVpdunRRUFCQ5s2bZ59f19fXV56eFw6HcnJydPjwYQUFBaly5crasGGDfT2ffvqpQkJC9Nhjjzlt49IpHa5Gjx49VLhwYfvt++67T5GRkfZpAzZu3Kjt27erW7duOnz4sDIyMpSRkaFTp06pRYsW+vbbb2Wz2RzWefbsWfn5+V12u5999plsNps6d+5sX2dGRoYiIiJUsWJFp6k0srOzJcl+tqorc+bMUZEiRdSyZUuHdcbGxiooKMhpnefOnXPol5GRccUzjPft26cJEybo+eefV1BQkNP2q1SpopiYGId15k5Xc+n2Xcmt6dChQ0pNTdW8efNUs2ZNhYSEXHHZQoUKqW/fvvbbPj4+6tu3rw4ePKj169dLkry8vOTj4yPpwhnoR44c0fnz51W3bl2Hv7dcx48fV3p6upYvX64FCxborrvucuqTnZ192efFlaFDh+r2229Xp06dHNpDQ0MlSXv37s3Xes6ePev0HJ47d85l39OnTysjI0NHjx51mJIjLxEREZo0aZKWLl2quLg4bdy4UdOnT1dwcHC+asvKylJGRoYOHjyopUuX6uuvv1aLFi2uuy5JGjx4sLy8vOy3H3jgAYWHh2vBggWSpHXr1ungwYPq16+fw2sxMTFRMTEx9n4Xu3gamNxpYbKzs7Vs2TKnvkOGDNGsWbP0ySefuPxGR+63QaQL3zQ4fvy44uLinP7GrvcxBgAAKGiYugQAANyyJk2
2024-11-01 23:33:34 +04:00
"text/plain": [
"<Figure size 1800x500 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ros = RandomOverSampler(random_state=42)\n",
"\n",
"# Применение RandomOverSampler для балансировки выборок\n",
"X_train_resampled, y_train_resampled = ros.fit_resample(X_train, y_train)\n",
"X_val_resampled, y_val_resampled = ros.fit_resample(X_val, y_val)\n",
"\n",
"# Проверка сбалансированности после RandomOverSampler\n",
"analyze_balance(y_train_resampled, y_val_resampled, y_test, 'stroke')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выборки сбалансированы.\n",
"\n",
"### Перейдем к конструированию признаков\n",
"\n",
"Для начала применим унитарное кодирование категориальных признаков (one-hot encoding), переведя их в бинарные вектора:"
]
},
{
"cell_type": "code",
"execution_count": 143,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" age hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 78.0 0 0 137.74 34.9 False \n",
"1 58.0 0 0 99.83 36.3 True \n",
"2 77.0 0 0 59.91 18.3 False \n",
"3 80.0 1 1 175.29 31.5 True \n",
"4 58.0 1 0 59.52 33.2 False \n",
2024-11-01 23:33:34 +04:00
"\n",
" gender_Other ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 False False False False \n",
"1 False True False False \n",
"2 False True False False \n",
"3 False True False True \n",
"4 False True False False \n",
2024-11-01 23:33:34 +04:00
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 True False True \n",
"1 True False False \n",
"2 True False False \n",
"3 False False True \n",
"4 False False False \n",
2024-11-01 23:33:34 +04:00
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 True False \n",
"1 False False \n",
"2 False True \n",
"3 True False \n",
"4 False True \n",
2024-11-01 23:33:34 +04:00
"\n",
" smoking_status_smokes \n",
"0 False \n",
"1 True \n",
"2 False \n",
"3 False \n",
"4 False \n"
2024-11-01 23:33:34 +04:00
]
}
],
"source": [
"# Определение категориальных признаков\n",
"categorical_features = ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']\n",
"\n",
"# Применение one-hot encoding к обучающей выборке\n",
"X_train_encoded = pd.get_dummies(X_train_resampled, columns=categorical_features, drop_first=True)\n",
"\n",
"# Применение one-hot encoding к контрольной выборке\n",
"X_val_encoded = pd.get_dummies(X_val_resampled, columns=categorical_features, drop_first=True)\n",
"\n",
"# Применение one-hot encoding к тестовой выборке\n",
"X_test_encoded = pd.get_dummies(X_test, columns=categorical_features, drop_first=True)\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Далее к числовым признакам, а именно к колонке age, применим дискретизацию (позволяет преобразовать данные из числового представления в категориальное):"
]
},
{
"cell_type": "code",
"execution_count": 144,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 0 0 137.74 34.9 False \n",
"1 0 0 99.83 36.3 True \n",
"2 0 0 59.91 18.3 False \n",
"3 1 1 175.29 31.5 True \n",
"4 1 0 59.52 33.2 False \n",
2024-11-01 23:33:34 +04:00
"\n",
" gender_Other ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 False False False False \n",
"1 False True False False \n",
2024-11-01 23:33:34 +04:00
"2 False True False False \n",
"3 False True False True \n",
"4 False True False False \n",
2024-11-01 23:33:34 +04:00
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 True False True \n",
"1 True False False \n",
"2 True False False \n",
"3 False False True \n",
"4 False False False \n",
2024-11-01 23:33:34 +04:00
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 True False \n",
2024-11-01 23:33:34 +04:00
"1 False False \n",
"2 False True \n",
"3 True False \n",
"4 False True \n",
2024-11-01 23:33:34 +04:00
"\n",
" smoking_status_smokes age_bin \n",
"0 False old \n",
"1 True old \n",
"2 False old \n",
"3 False old \n",
"4 False old \n"
2024-11-01 23:33:34 +04:00
]
}
],
"source": [
"# Определение числовых признаков для дискретизации\n",
"numerical_features = ['age']\n",
"\n",
"# Функция для дискретизации числовых признаков\n",
"def discretize_features(df, features, bins, labels):\n",
" for feature in features:\n",
" df[f'{feature}_bin'] = pd.cut(df[feature], bins=bins, labels=labels)\n",
" df.drop(columns=[feature], inplace=True)\n",
" return df\n",
"\n",
"# Заданные интервалы и метки\n",
"age_bins = [0, 25, 55, 100]\n",
"age_labels = [\"young\", \"middle-aged\", \"old\"]\n",
"\n",
"# Применение дискретизации к обучающей, контрольной и тестовой выборкам\n",
"X_train_encoded = discretize_features(X_train_encoded, numerical_features, bins=age_bins, labels=age_labels)\n",
"X_val_encoded = discretize_features(X_val_encoded, numerical_features, bins=age_bins, labels=age_labels)\n",
"X_test_encoded = discretize_features(X_test_encoded, numerical_features, bins=age_bins, labels=age_labels)\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Применим ручной синтез признаков. Это создание новых признаков на основе существующих, учитывая экспертные знания и логику предметной области. К примеру, в этом случае можно создать признак, в котором вычисляется насколько уровень глюкозы отклоняется от среднего для возрастной группы пациента. Такой признак может быть полезен для выделения пациентов с нетипичными данными."
]
},
{
"cell_type": "code",
"execution_count": 145,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 0 0 137.74 34.9 False \n",
"1 0 0 99.83 36.3 True \n",
"2 0 0 59.91 18.3 False \n",
"3 1 1 175.29 31.5 True \n",
"4 1 0 59.52 33.2 False \n",
2024-11-01 23:33:34 +04:00
"\n",
" gender_Other ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 False False False False \n",
"1 False True False False \n",
2024-11-01 23:33:34 +04:00
"2 False True False False \n",
"3 False True False True \n",
"4 False True False False \n",
2024-11-01 23:33:34 +04:00
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 True False True \n",
"1 True False False \n",
"2 True False False \n",
"3 False False True \n",
"4 False False False \n",
2024-11-01 23:33:34 +04:00
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 True False \n",
2024-11-01 23:33:34 +04:00
"1 False False \n",
"2 False True \n",
"3 True False \n",
"4 False True \n",
2024-11-01 23:33:34 +04:00
"\n",
" smoking_status_smokes age_bin glucose_age_deviation \n",
"0 False old 7.343213 \n",
"1 True old -30.566787 \n",
"2 False old -70.486787 \n",
"3 False old 44.893213 \n",
"4 False old -70.876787 \n"
2024-11-01 23:33:34 +04:00
]
}
],
"source": [
"age_glucose_mean = X_train_encoded.groupby('age_bin', observed=False)['avg_glucose_level'].transform('mean')\n",
"X_train_encoded['glucose_age_deviation'] = X_train_encoded['avg_glucose_level'] - age_glucose_mean\n",
"\n",
"age_glucose_mean = X_val_encoded.groupby('age_bin', observed=False)['avg_glucose_level'].transform('mean')\n",
"X_val_encoded['glucose_age_deviation'] = X_val_encoded['avg_glucose_level'] - age_glucose_mean\n",
"\n",
"age_glucose_mean = X_test_encoded.groupby('age_bin', observed=False)['avg_glucose_level'].transform('mean')\n",
"X_test_encoded['glucose_age_deviation'] = X_test_encoded['avg_glucose_level'] - age_glucose_mean\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Теперь используем масштабирование признаков, что позволяет привести все числовые признаки к одинаковым или очень похожим диапазонам значений либо распределениям. По результатам многочисленных исследований масштабирование признаков позволяет получить более качественную модель за счет снижения доминирования одних признаков над другими."
]
},
{
"cell_type": "code",
"execution_count": 146,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 0 0 0.366350 0.716465 False \n",
"1 0 0 -0.312101 0.912920 True \n",
"2 0 0 -1.026524 -1.612927 False \n",
"3 1 1 1.038358 0.239361 True \n",
"4 1 0 -1.033504 0.477913 False \n",
2024-11-01 23:33:34 +04:00
"\n",
" gender_Other ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 False False False False \n",
"1 False True False False \n",
"2 False True False False \n",
"3 False True False True \n",
"4 False True False False \n",
2024-11-01 23:33:34 +04:00
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 True False True \n",
"1 True False False \n",
"2 True False False \n",
"3 False False True \n",
"4 False False False \n",
2024-11-01 23:33:34 +04:00
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 True False \n",
"1 False False \n",
"2 False True \n",
"3 True False \n",
"4 False True \n",
2024-11-01 23:33:34 +04:00
"\n",
" smoking_status_smokes age_bin glucose_age_deviation \n",
"0 False old 0.136565 \n",
"1 True old -0.568466 \n",
"2 False old -1.310877 \n",
"3 False old 0.834901 \n",
"4 False old -1.318130 \n"
2024-11-01 23:33:34 +04:00
]
}
],
"source": [
"# Пример масштабирования числовых признаков\n",
"numerical_features = ['avg_glucose_level', 'bmi', 'glucose_age_deviation']\n",
"\n",
"scaler = StandardScaler()\n",
"X_train_encoded[numerical_features] = scaler.fit_transform(X_train_encoded[numerical_features])\n",
"X_val_encoded[numerical_features] = scaler.transform(X_val_encoded[numerical_features])\n",
"X_test_encoded[numerical_features] = scaler.transform(X_test_encoded[numerical_features])\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"И также попробуем сконструировать признаки, используя фреймворк Featuretools:"
]
},
{
"cell_type": "code",
"execution_count": 147,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"index \n",
"0 0 0 0.366350 0.716465 False \n",
"1 0 0 -0.312101 0.912920 True \n",
"2 0 0 -1.026524 -1.612927 False \n",
"3 1 1 1.038358 0.239361 True \n",
"4 1 0 -1.033504 0.477913 False \n",
2024-11-01 23:33:34 +04:00
"\n",
" gender_Other ever_married_Yes work_type_Never_worked \\\n",
"index \n",
"0 False False False \n",
"1 False True False \n",
"2 False True False \n",
"3 False True False \n",
"4 False True False \n",
2024-11-01 23:33:34 +04:00
"\n",
" work_type_Private work_type_Self-employed work_type_children \\\n",
"index \n",
"0 False True False \n",
"1 False True False \n",
"2 False True False \n",
2024-11-01 23:33:34 +04:00
"3 True False False \n",
"4 False False False \n",
2024-11-01 23:33:34 +04:00
"\n",
" Residence_type_Urban smoking_status_formerly smoked \\\n",
"index \n",
"0 True True \n",
"1 False False \n",
"2 False False \n",
"3 True True \n",
"4 False False \n",
2024-11-01 23:33:34 +04:00
"\n",
" smoking_status_never smoked smoking_status_smokes age_bin \\\n",
"index \n",
"0 False False old \n",
"1 False True old \n",
"2 True False old \n",
"3 False False old \n",
"4 True False old \n",
2024-11-01 23:33:34 +04:00
"\n",
" glucose_age_deviation \n",
"index \n",
"0 0.136565 \n",
"1 -0.568466 \n",
"2 -1.310877 \n",
"3 0.834901 \n",
"4 -1.318130 \n"
2024-11-01 23:33:34 +04:00
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Ilya\\Desktop\\AIM\\aimenv\\Lib\\site-packages\\woodwork\\type_sys\\utils.py:33: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
" pd.to_datetime(\n"
]
}
],
"source": [
"data = X_train_encoded.copy() # Используем предобработанные данные\n",
"\n",
"es = ft.EntitySet(id=\"patients\")\n",
"\n",
"es = es.add_dataframe(dataframe_name=\"strokes_data\", dataframe=data, index=\"index\", make_index=True)\n",
"\n",
"feature_matrix, feature_defs = ft.dfs(\n",
" entityset=es, \n",
" target_dataframe_name=\"strokes_data\",\n",
" max_depth=1\n",
")\n",
"\n",
"print(feature_matrix.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Оценим качество набора признаков.\n",
"\n",
"Представим основные оценки качества наборов признаков: \n",
"\n",
"* Предсказательная способность (для задачи классификации) Метрики: Accuracy, Precision, Recall, F1-Score, ROC AUC\n",
2024-11-01 23:33:34 +04:00
"\n",
" Методы: Обучение модели на обучающей выборке и оценка на контрольной и тестовой выборках.\n",
"\n",
"* Скорость вычисления \n",
"\n",
" Методы: Измерение времени выполнения генерации признаков и обучения модели.\n",
"\n",
"* Надежность \n",
"\n",
" Методы: Кросс-валидация, анализ чувствительности модели к изменениям в данных.\n",
"\n",
"* Корреляция \n",
"\n",
" Методы: Анализ корреляционной матрицы признаков, удаление мультиколлинеарных признаков.\n",
"\n",
"* Цельность \n",
"\n",
" Методы: Проверка логической связи между признаками и целевой переменной, интерпретация результатов модели."
]
},
{
"cell_type": "code",
"execution_count": 148,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Время обучения модели: 0.64 секунд\n"
2024-11-01 23:33:34 +04:00
]
}
],
"source": [
"X_train_encoded = pd.get_dummies(X_train_encoded, drop_first=True)\n",
"X_val_encoded = pd.get_dummies(X_val_encoded, drop_first=True)\n",
"X_test_encoded = pd.get_dummies(X_test_encoded, drop_first=True)\n",
"\n",
"all_columns = X_train_encoded.columns\n",
"X_train_encoded = X_train_encoded.reindex(columns=all_columns, fill_value=0)\n",
"X_val_encoded = X_val_encoded.reindex(columns=all_columns, fill_value=0)\n",
"X_test_encoded = X_test_encoded.reindex(columns=all_columns, fill_value=0)\n",
"\n",
"# Выбор модели\n",
"model = RandomForestClassifier(n_estimators=100, random_state=42)\n",
2024-11-01 23:33:34 +04:00
"\n",
"# Начинаем отсчет времени\n",
"start_time = time.time()\n",
"model.fit(X_train_encoded, y_train_resampled)\n",
"\n",
"# Время обучения модели\n",
"train_time = time.time() - start_time\n",
"\n",
"print(f'Время обучения модели: {train_time:.2f} секунд')"
2024-11-01 23:33:34 +04:00
]
},
{
"cell_type": "code",
"execution_count": 149,
2024-11-01 23:33:34 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Feature Importance:\n",
" feature importance\n",
"3 bmi 0.186627\n",
"15 glucose_age_deviation 0.185001\n",
"2 avg_glucose_level 0.179305\n",
"17 age_bin_old 0.166728\n",
"0 hypertension 0.040494\n",
"16 age_bin_middle-aged 0.033330\n",
"11 Residence_type_Urban 0.028735\n",
"4 gender_Male 0.028446\n",
"6 ever_married_Yes 0.026005\n",
"1 heart_disease 0.023176\n",
"13 smoking_status_never smoked 0.021729\n",
"14 smoking_status_smokes 0.019693\n",
"8 work_type_Private 0.018582\n",
"9 work_type_Self-employed 0.017155\n",
"12 smoking_status_formerly smoked 0.015585\n",
"10 work_type_children 0.009287\n",
"7 work_type_Never_worked 0.000118\n",
"5 gender_Other 0.000002\n"
2024-11-01 23:33:34 +04:00
]
}
],
"source": [
"# Получение важности признаков\n",
"importances = model.feature_importances_\n",
"feature_names = X_train_encoded.columns\n",
2024-11-01 23:33:34 +04:00
"\n",
"# Сортировка признаков по важности\n",
"feature_importance = pd.DataFrame({'feature': feature_names, 'importance': importances})\n",
"feature_importance = feature_importance.sort_values(by='importance', ascending=False)\n",
2024-11-01 23:33:34 +04:00
"\n",
"print(\"Feature Importance:\")\n",
"print(feature_importance)"
]
},
{
"cell_type": "code",
"execution_count": 150,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.9425962165688193\n",
"Precision: 0.1\n",
"Recall: 0.027777777777777776\n",
"F1 Score: 0.043478260869565216\n",
"ROC AUC: 0.5077287246178417\n",
"Cross-validated Accuracy: 0.9926410942926067\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABCEAAAIjCAYAAAA9agHPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADepUlEQVR4nOzdd1zW1f//8ccFKHuIooKigCKiooJ7m1LuHKVpFG6z3GaalQNz5ciRqaUJWo7K1Pzkyj3QFE3J1HAiVhiVCuJAGb8//Pn+dgUqoEHa8367Xbcb1znnfc7r/b7IW9eLM0wZGRkZiIiIiIiIiIj8wyzyOwARERERERER+W9QEkJERERERERE8oSSECIiIiIiIiKSJ5SEEBEREREREZE8oSSEiIiIiIiIiOQJJSFEREREREREJE8oCSEiIiIiIiIieUJJCBERERERERHJE0pCiIiIiIiIiEieUBJCRERERERERPKEkhAiIiLynxYREYHJZMry9eabb/4jY+7du5exY8dy5cqVf6T/h3H3eRw8eDC/Q8m1uXPnEhERkd9hiIhIFqzyOwARERGRf4Nx48bh7e1tVlapUqV/ZKy9e/cSFhZGt27dcHFx+UfG+C+bO3cuRYoUoVu3bvkdioiI/I2SECIiIiJAixYtqF69en6H8VCuXbuGvb19foeRb65fv46dnV1+hyEiIveh5RgiIiIi2bBhwwYaNGiAvb09jo6OtGrVimPHjpm1+eGHH+jWrRs+Pj7Y2NhQvHhxevTowZ9//mm0GTt2LG+88QYA3t7extKP2NhYYmNjMZlMWS4lMJlMjB071qwfk8nE8ePHefHFFylUqBD169c36j/77DOqVauGra0trq6udO7cmQsXLuTq3rt164aDgwNxcXG0bt0aBwcHSpQowYcffgjA0aNHadKkCfb29pQuXZply5aZXX93iceuXbt45ZVXKFy4ME5OToSGhnL58uVM482dO5eKFStibW2Nh4cH/fr1y7R0pXHjxlSqVIlDhw7RsGFD7OzseOutt/Dy8uLYsWPs3LnTeLaNGzcG4NKlSwwbNoyAgAAcHBxwcnKiRYsWREdHm/W9Y8cOTCYTX3zxBRMmTKBkyZLY2NjQtGlTTp8+nSne/fv307JlSwoVKoS9vT2VK1dm1qxZZm1++uknnn/+eVxdXbGxsaF69eqsXbs2px+FiMhjTzMhRERERIDExET++OMPs7IiRYoA8Omnn9K1a1eaNWvGe++9x/Xr15k3bx7169fn8OHDeHl5AbB582bOnj1L9+7dKV68OMeOHePjjz/m2LFjfPfdd5hMJjp06MDJkydZvnw5M2bMMMZwc3Pj999/z3HcHTt2xNfXl4kTJ5KRkQHAhAkTGDVqFJ06daJXr178/vvvfPDBBzRs2JDDhw/naglIWloaLVq0oGHDhkyZMoWlS5fSv39/7O3tefvttwkJCaFDhw7Mnz+f0NBQ6tSpk2l5S//+/XFxcWHs2LHExMQwb948zp8/b3zphzvJlbCwMIKDg3n11VeNdlFRUURGRlKgQAGjvz///JMWLVrQuXNnXnrpJYoVK0bjxo0ZMGAADg4OvP322wAUK1YMgLNnz7JmzRo6duyIt7c3v/32Gx999BGNGjXi+PHjeHh4mMU7efJkLCwsGDZsGImJiUyZMoWQkBD2799vtNm8eTOtW7fG3d2dQYMGUbx4cU6cOME333zDoEGDADh27Bj16tWjRIkSvPnmm9jb2/PFF1/Qrl07vvrqK9q3b5/jz0NE5LGVISIiIvIfFh4engFk+crIyMi4evVqhouLS0bv3r3Nrrt48WKGs7OzWfn169cz9b98+fIMIGPXrl1G2dSpUzOAjHPnzpm1PXfuXAaQER4enqkfIGPMmDHG+zFjxmQAGV26dDFrFxsbm2FpaZkxYcIEs/KjR49mWFlZZSq/1/OIiooyyrp27ZoBZEycONEou3z5coatrW2GyWTKWLFihVH+008/ZYr1bp/VqlXLuHXrllE+ZcqUDCDj66+/zsjIyMhISEjIKFiwYMYzzzyTkZaWZrSbM2dOBpCxaNEio6xRo0YZQMb8+fMz3UPFihUzGjVqlKn85s2bZv1mZNx55tbW1hnjxo0zyrZv354BZPj7+2ekpKQY5bNmzcoAMo4ePZqRkZGRkZqamuHt7Z1RunTpjMuXL5v1m56ebvzctGnTjICAgIybN2+a1detWzfD19c3U5wiIk8yLccQERERAT788EM2b95s9oI7f+m+cuUKXbp04Y8//jBelpaW1KpVi+3btxt92NraGj/fvHmTP/74g9q1awPw/fff/yNx9+3b1+z9qlWrSE9Pp1OnTmbxFi9eHF9fX7N4c6pXr17Gzy4uLvj5+WFvb0+nTp2Mcj8/P1xcXDh79mym6/v06WM2k+HVV1/FysqK9evXA7BlyxZu3brF4MGDsbD4v/9N7d27N05OTqxbt86sP2tra7p3757t+K2trY1+09LS+PPPP3FwcMDPzy/Lz6d79+4ULFjQeN+gQQMA494OHz7MuXPnGDx4cKbZJXdndly6dIlt27bRqVMnrl69anwef/75J82aNePUqVP88ssv2b4HEZHHnZZjiIiIiAA1a9bMcmPKU6dOAdCkSZMsr3NycjJ+vnTpEmFhYaxYsYKEhASzdomJiY8w2v/z9yUPp06dIiMjA19f3yzb/zUJkBM2Nja4ubmZlTk7O1OyZEnjC/dfy7Pa6+HvMTk4OODu7k5sbCwA58+fB+4kMv6qYMGC+Pj4GPV3lShRwixJ8CDp6enMmjWLuXPncu7cOdLS0oy6woULZ2pfqlQps/eFChUCMO7tzJkzwP1PUTl9+jQZGRmMGjWKUaNGZdkmISGBEiVKZPs+REQeZ0pCiIiIiNxHeno6cGdfiOLFi2eqt7L6v/+d6tSpE3v37uWNN96gatWqODg4kJ6eTvPmzY1+7ufvX+bv+uuX5b/76+yLu/GaTCY2bNiApaVlpvYODg4PjCMrWfV1v/KM/78/xT/p7/f+IBMnTmTUqFH06NGDd999F1dXVywsLBg8eHCWn8+juLe7/Q4bNoxmzZpl2aZs2bLZ7k9E5HGnJISIiIjIfZQpUwaAokWLEhwcfM92ly9fZuvWrYSFhTF69Gij/O5Mir+6V7Lh7l/a/34SxN9nADwo3oyMDLy9vSlXrly2r8sLp06d4qmnnjLeJycnEx8fT8uWLQEoXbo0ADExMfj4+Bjtbt26xblz5+77/P/qXs935cqVPPXUU3zyySdm5VeuXDE2CM2Ju78bP/744z1ju3sfBQoUyHb8IiJPMu0JISIiInIfzZo1w8nJiYkTJ3L79u1M9XdPtLj7V/O//5V85syZma6xt7cHMicbnJycKFKkCLt27TIrnzt3brbj7dChA5aWloSFhWWKJSMjw+y40Lz28ccfmz3DefPmkZqaSosWLQAIDg6mYMGCzJ492yz2Tz75hMTERFq1apWtcezt7TM9W7jzGf39mXz55Ze53pMhKCgIb29vZs6cmWm8u+MULVqUxo0b89FHHxEfH5+pj9yciCIi8jjTTAgRERGR+3BycmLevHm8/PLLBAUF0blzZ9zc3IiLi2PdunXUq1ePOXPm4OTkZBxfefv2bUqUKMG3337LuXPnMvVZrVo1AN5++206d+5MgQIFaNOmDfb29vTq1YvJkyfTq1cvqlevzq5duzh58mS24y1Tpgzjx49n5MiRxMbG0q5dOxwdHTl37hyrV6+mT58+DBs27JE9n5y4desWTZs2pVOnTsTExDB37lzq16/Ps88+C9w5pnTkyJGEhYXRvHlznn32WaNdjRo1eOmll7I1TrVq1Zg3bx7jx4+nbNmyFC1alCZNmtC6dWvGjRtH9+7dqVu3LkePHmXp0qVmsy5ywsLCgnnz5tGmTRuqVq1K9+7dcXd356effuLYsWNs2rQJuLPpaf369QkICKB37974+Pjw22+/sW/fPn7++Weio6NzNb6IyONISQgRERGRB3jxxRfx8PBg8uTJTJ06lZSUFEqUKEGDBg3MTmdYtmwZAwYM4MMPPyQjI4NnnnmGDRs
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Accuracy: 1.0\n",
"Train Precision: 1.0\n",
"Train Recall: 1.0\n",
"Train F1 Score: 1.0\n",
"Train ROC AUC: 1.0\n"
]
}
],
"source": [
2024-11-01 23:33:34 +04:00
"# Предсказание и оценка\n",
"y_pred = model.predict(X_test_encoded)\n",
"\n",
"accuracy = accuracy_score(y_test, y_pred)\n",
"precision = precision_score(y_test, y_pred)\n",
"recall = recall_score(y_test, y_pred)\n",
"f1 = f1_score(y_test, y_pred)\n",
"roc_auc = roc_auc_score(y_test, y_pred)\n",
2024-11-01 23:33:34 +04:00
"\n",
"print(f\"Accuracy: {accuracy}\")\n",
"print(f\"Precision: {precision}\")\n",
"print(f\"Recall: {recall}\")\n",
"print(f\"F1 Score: {f1}\")\n",
"print(f\"ROC AUC: {roc_auc}\")\n",
2024-11-01 23:33:34 +04:00
"\n",
"# Кросс-валидация\n",
"scores = cross_val_score(model, X_train_encoded, y_train_resampled, cv=5, scoring='accuracy')\n",
"accuracy_cv = scores.mean()\n",
"print(f\"Cross-validated Accuracy: {accuracy_cv}\")\n",
"\n",
"# Анализ важности признаков\n",
"feature_importances = model.feature_importances_\n",
"feature_names = X_train_encoded.columns\n",
"\n",
"importance_df = pd.DataFrame({'Feature': feature_names, 'Importance': feature_importances})\n",
"importance_df = importance_df.sort_values(by='Importance', ascending=False)\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"sns.barplot(x='Importance', y='Feature', data=importance_df)\n",
"plt.title('Feature Importance')\n",
"plt.show()\n",
2024-11-01 23:33:34 +04:00
"\n",
"# Проверка на переобучение\n",
"y_train_pred = model.predict(X_train_encoded)\n",
"\n",
"accuracy_train = accuracy_score(y_train_resampled, y_train_pred)\n",
"precision_train = precision_score(y_train_resampled, y_train_pred)\n",
"recall_train = recall_score(y_train_resampled, y_train_pred)\n",
"f1_train = f1_score(y_train_resampled, y_train_pred)\n",
"roc_auc_train = roc_auc_score(y_train_resampled, y_train_pred)\n",
"\n",
"print(f\"Train Accuracy: {accuracy_train}\")\n",
"print(f\"Train Precision: {precision_train}\")\n",
"print(f\"Train Recall: {recall_train}\")\n",
"print(f\"Train F1 Score: {f1_train}\")\n",
"print(f\"Train ROC AUC: {roc_auc_train}\")"
2024-11-01 23:33:34 +04:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}