1113 lines
188 KiB
Plaintext
Raw Permalink Normal View History

2024-11-29 18:17:46 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Вариант 4. Данные по инсультам"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['id', 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married',\n",
" 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi',\n",
" 'smoking_status', 'stroke'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9046</td>\n",
" <td>Male</td>\n",
" <td>67.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>228.69</td>\n",
" <td>36.6</td>\n",
" <td>formerly smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>51676</td>\n",
" <td>Female</td>\n",
" <td>61.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>202.21</td>\n",
" <td>NaN</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>31112</td>\n",
" <td>Male</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>105.92</td>\n",
" <td>32.5</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>60182</td>\n",
" <td>Female</td>\n",
" <td>49.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>171.23</td>\n",
" <td>34.4</td>\n",
" <td>smokes</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1665</td>\n",
" <td>Female</td>\n",
" <td>79.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>174.12</td>\n",
" <td>24.0</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id gender age hypertension heart_disease ever_married \\\n",
"0 9046 Male 67.0 0 1 Yes \n",
"1 51676 Female 61.0 0 0 Yes \n",
"2 31112 Male 80.0 0 1 Yes \n",
"3 60182 Female 49.0 0 0 Yes \n",
"4 1665 Female 79.0 1 0 Yes \n",
"\n",
" work_type Residence_type avg_glucose_level bmi smoking_status \\\n",
"0 Private Urban 228.69 36.6 formerly smoked \n",
"1 Self-employed Rural 202.21 NaN never smoked \n",
"2 Private Rural 105.92 32.5 never smoked \n",
"3 Private Urban 171.23 34.4 smokes \n",
"4 Self-employed Rural 174.12 24.0 never smoked \n",
"\n",
" stroke \n",
"0 1 \n",
"1 1 \n",
"2 1 \n",
"3 1 \n",
"4 1 "
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from sklearn.preprocessing import StandardScaler\n",
"import featuretools as ft\n",
"import time\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"df = pd.read_csv(\"../data/healthcare-dataset-stroke-data.csv\")\n",
"\n",
"print(df.columns)\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Бизнес цели и цели технического проекта.\n",
"## Бизнес цели:\n",
"### 1. Предсказание инсульта: Разработать систему, которая сможет предсказать вероятность инсульта у пациентов на основе их медицинских и социальных данных. Это может помочь медицинским учреждениям и специалистам в более раннем выявлении пациентов с высоким риском.\n",
"### 2. Снижение затрат на лечение: Предупреждение инсультов у пациентов позволит снизить затраты на лечение и реабилитацию. Это также поможет улучшить качество медицинских услуг и повысить удовлетворенность пациентов.\n",
"### 3. Повышение эффективности профилактики: Выявление факторов риска инсульта на ранней стадии может способствовать более эффективному проведению профилактических мероприятий.\n",
"## Цели технического проекта:\n",
"### 1. Создание и обучение модели машинного обучения: Разработка модели, способной предсказать вероятность инсульта на основе данных о пациентах (например, возраст, уровень глюкозы, наличие сердечно-сосудистых заболеваний, тип работы, индекс массы тела и т.д.).\n",
"### 2. Анализ и обработка данных: Провести предобработку данных (очистка, заполнение пропущенных значений, кодирование категориальных признаков), чтобы улучшить качество и надежность модели.\n",
"### 3. Оценка модели: Использовать метрики, такие как точность, полнота и F1-мера, чтобы оценить эффективность модели и минимизировать риск ложных положительных и ложных отрицательных предсказаний."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"id 0\n",
"gender 0\n",
"age 0\n",
"hypertension 0\n",
"heart_disease 0\n",
"ever_married 0\n",
"work_type 0\n",
"Residence_type 0\n",
"avg_glucose_level 0\n",
"bmi 201\n",
"smoking_status 0\n",
"stroke 0\n",
"dtype: int64\n",
"\n",
"id False\n",
"gender False\n",
"age False\n",
"hypertension False\n",
"heart_disease False\n",
"ever_married False\n",
"work_type False\n",
"Residence_type False\n",
"avg_glucose_level False\n",
"bmi True\n",
"smoking_status False\n",
"stroke False\n",
"dtype: bool\n",
"\n",
"bmi процент пустых значений: %3.93\n"
]
}
],
"source": [
"print(df.isnull().sum())\n",
"print()\n",
"\n",
"print(df.isnull().any())\n",
"print()\n",
"\n",
"for i in df.columns:\n",
" null_rate = df[i].isnull().sum() / len(df) * 100\n",
" if null_rate > 0:\n",
" print(f\"{i} процент пустых значений: %{null_rate:.2f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Видим пустые значения в bmi, заменяем их"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Количество пустых значений в каждом столбце после замены:\n",
"id 0\n",
"gender 0\n",
"age 0\n",
"hypertension 0\n",
"heart_disease 0\n",
"ever_married 0\n",
"work_type 0\n",
"Residence_type 0\n",
"avg_glucose_level 0\n",
"bmi 0\n",
"smoking_status 0\n",
"stroke 0\n",
"dtype: int64\n"
]
}
],
"source": [
"df[\"bmi\"] = df[\"bmi\"].fillna(df[\"bmi\"].median())\n",
"\n",
"missing_values = df.isnull().sum()\n",
"\n",
"print(\"Количество пустых значений в каждом столбце после замены:\")\n",
"print(missing_values)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',\n",
" 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi',\n",
" 'smoking_status', 'stroke'],\n",
" dtype='object')\n"
]
}
],
"source": [
"df = df.drop('id', axis = 1)\n",
"print(df.columns)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Создаем выборки"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Размер обучающей выборки: (2503, 10)\n",
"Размер контрольной выборки: (1074, 10)\n",
"Размер тестовой выборки: (1533, 10)\n"
]
}
],
"source": [
"# Разделим данные на признак (X) и переменую (Y)\n",
"# Начнем со stroke\n",
"X = df.drop(columns=['stroke'])\n",
"y = df['stroke']\n",
"\n",
"# Разбиваем на обучающую и тестовую выборки\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n",
"\n",
"# Разбиваем на обучающую и контрольную выборки\n",
"X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.3)\n",
"\n",
"print(\"Размер обучающей выборки: \", X_train.shape)\n",
"print(\"Размер контрольной выборки: \", X_val.shape)\n",
"print(\"Размер тестовой выборки: \", X_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Оценим сбалансированность сборок"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Распределение классов в обучающей выборке:\n",
"stroke\n",
"0 0.94966\n",
"1 0.05034\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в контрольной выборке:\n",
"stroke\n",
"0 0.946927\n",
"1 0.053073\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в тестовой выборке:\n",
"stroke\n",
"0 0.956947\n",
"1 0.043053\n",
"Name: proportion, dtype: float64\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABboAAAHyCAYAAAAtJXgGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABpg0lEQVR4nO3deZyNdf/H8feZMftYYlZLM7aMncYSkm0YjK3uyHJnqFCo0K2iMmiZpEQoS9FCd6LoLrJGSRMhRSQxsg9jG+sMc76/Pzzm/BznDGPJmUuv5+NxHo853/O9rutznTPnfK/zPtf5HpsxxggAAAAAAAAAAIvy8nQBAAAAAAAAAABcD4JuAAAAAAAAAIClEXQDAAAAAAAAACyNoBsAAAAAAAAAYGkE3QAAAAAAAAAASyPoBgAAAAAAAABYGkE3AAAAAAAAAMDSCLoBAAAAAAAAAJZG0A0AAAAA18Butys9PV07duzwdCkAAAD/eATdAAAAAJBHBw4c0IABAxQVFSVfX1+FhoaqUqVKysjI8HRpAAAA/2gFPF0AAADAjfb++++rZ8+ejut+fn66/fbb1aJFC73wwgsKDw/3YHUArOrPP/9UkyZNdO7cOT3xxBO68847VaBAAQUEBCgoKMjT5QEAAPyjEXQDAIBb1siRI1W6dGmdPXtW33//vd555x0tWLBAmzZtUmBgoKfLA2Axffr0ka+vr3788UeVKFHC0+UAAADgIgTdAADgltWqVSvVqlVLkvTII4+oWLFiGjNmjL744gt16dLFw9UBsJJ169bpm2++0eLFiwm5AQAA8iHm6AYAAP8YTZs2lSSlpqZKko4cOaL//Oc/qlq1qoKDg1WoUCG1atVKv/zyi8uyZ8+e1fDhw3XHHXfI399fkZGRuu+++7R9+3ZJ0s6dO2Wz2XK9NG7c2LGuFStWyGazadasWRo6dKgiIiIUFBSkdu3aaffu3S7bXr16tVq2bKnChQsrMDBQjRo10qpVq9zuY+PGjd1uf/jw4S59Z8yYodjYWAUEBKho0aLq3Lmz2+1fbt8uZrfbNXbsWFWuXFn+/v4KDw9Xnz59dPToUad+0dHRatOmjct2+vfv77JOd7WPHj3a5T6VpMzMTCUlJalcuXLy8/NTqVKl9PTTTyszM9PtfXWxS++3kJAQJSQkaNOmTXlatkqVKlq3bp3q16+vgIAAlS5dWpMmTXLql5WVpWHDhik2NlaFCxdWUFCQGjZsqOXLlzv127p1q5o2baqIiAjHfjz66KM6cuSIy7Z79Ohxxce7R48eio6Odlpu9+7dCggIkM1m086dOyX9/+P8/vvvO/UdPny428elf//+LvW0adPGaVs563z99ddzufdc1z99+nTZbDZNmzbNqd8rr7wim82mBQsW5Lou6cL/V8794OXlpYiICD3wwAPatWvXddX1448/yt/fX9u3b1flypXl5+eniIgI9enTx+1jM3v2bMfzKyQkRP/+97+1d+9epz49evRQcHCwduzYofj4eAUFBal48eIaOXKkjDEu9V782Jw4cUKxsbEqXbq09u/f72h//fXXVb9+fRUrVkwBAQGKjY3VnDlznLZ7vfcxAABAfsQZ3QAA4B8jJ5QuVqyYJGnHjh2aN2+eOnbsqNKlSystLU2TJ09Wo0aNtHnzZhUvXlySlJ2drTZt2mjZsmXq3LmznnzySZ04cUJLlizRpk2bVLZsWcc2unTpotatWzttd8iQIW7refnll2Wz2fTMM8/o4MGDGjt2rOLi4rRhwwYFBARIkr755hu1atVKsbGxSkpKkpeXl6ZPn66mTZtq5cqVqlOnjst6S5YsqeTkZEnSyZMn9dhjj7nd9gsvvKBOnTrpkUce0aFDhzR+/Hjdc889+vnnn1WkSBGXZXr37q2GDRtKkj7//HPNnTvX6fY+ffo45kd/4oknlJqaqgkTJujnn3/WqlWr5OPj4/Z+uBrHjh1z7NvF7Ha72rVrp++//169e/dWxYoVtXHjRr355pv6448/NG/evCuuOyYmRs8995yMMdq+fbvGjBmj1q1bOwWkuTl69Khat26tTp06qUuXLvr000/12GOPydfXVw899JAkKSMjQ++++666dOmiXr166cSJE3rvvfcUHx+vNWvWqEaNGpKkU6dOqWTJkmrbtq0KFSqkTZs2aeLEidq7d6++/PJLl22HhITozTffdFx/8MEHr1jvsGHDdPbs2Sv284SePXvq888/16BBg9S8eXOVKlVKGzdu1IgRI/Twww+7PL/cadiwoXr37i273a5NmzZp7Nix2rdvn1auXHnNdR0+fFhnz57VY489pqZNm+rRRx/V9u3bNXHiRK1evVqrV6+Wn5+fpP//nYDatWsrOTlZaWlpGjdunFatWuXy/MrOzlbLli1111136bXXXtPChQuVlJSk8+fPa+TIkW5rOXfunP71r39p165dWrVqlSIjIx23jRs3Tu3atVO3bt2UlZWlTz75RB07dtRXX32lhISEG3YfAwAA5DsGAADgFjN9+nQjySxdutQcOnTI7N6923zyySemWLFiJiAgwOzZs8cYY8zZs2dNdna207KpqanGz8/PjBw50tE2bdo0I8mMGTPGZVt2u92xnCQzevRolz6VK1c2jRo1clxfvny5kWRKlChhMjIyHO2ffvqpkWTGjRvnWHf58uVNfHy8YzvGGHP69GlTunRp07x5c5dt1a9f31SpUsVx/dChQ0aSSUpKcrTt3LnTeHt7m5dfftlp2Y0bN5oCBQq4tG/bts1IMh988IGjLSkpyVx8KLly5UojycycOdNp2YULF7q0R0VFmYSEBJfa+/XrZy49PL209qefftqEhYWZ2NhYp/v0o48+Ml5eXmblypVOy0+aNMlIMqtWrXLZ3sUaNWrktD5jjBk6dKiRZA4ePHjFZSWZN954w9GWmZlpatSoYcLCwkxWVpYxxpjz58+bzMxMp2WPHj1qwsPDzUMPPXTZbfTt29cEBwe7tHfr1s2ULl3aqe3S+ywxMdFERUU5rm/atMl4eXmZVq1aGUkmNTXVGGPMX3/9ZSSZadOmOa3v0sc6Zxv9+vVzqSchIcFpW5d7Xlxu/fv37zdFixY1zZs3N5mZmaZmzZrm9ttvN8ePH891PTmioqJMYmKiU1vXrl1NYGDgddWVc71Zs2bm/Pnzjvac15vx48cbY4zJysoyYWFhpkqVKubMmTOOfl999ZWRZIYNG+ZoS0xMNJLM448/7miz2+0mISHB+Pr6mkOHDjnVO336dGO32023bt1MYGCgWb16tUvdp0+fdrqelZVlqlSpYpo2berUfj33MQAAQH7E1CUAAOCWFRcXp9DQUJUqVUqdO3dWcHCw5s6d65hf18/PT15eFw6HsrOzdfjwYQUHB6tChQpav369Yz2fffaZQkJC9Pjjj7ts49IpHa5G9+7dVbBgQcf1+++/X5GRkY5pAzZs2KBt27apa9euOnz4sNLT05Wenq5Tp06pWbNm+u6772S3253WefbsWfn7+192u59//rnsdrs6derkWGd6eroiIiJUvnx5l6k0srKyJMlxtqo7s2fPVuHChdW8eXOndcbGxio4ONhlnefOnXPql56efsUzjPfu3avx48frhRdeUHBwsMv2K1asqJiYGKd15kxXc+n23cmp6dChQ0pJSdHcuXNVrVo1hYSEXHHZAgUKqE+fPo7rvr6+6tOnjw4ePKh169ZJkry9veXr6yvpwhnoR44c0fnz51WrVi2n/7ccx48fV1pampYtW6b58+frnnvucemTlZV12cfFnSFDhujOO+9Ux44dndpDQ0MlSXv27MnTes6ePevyGJ47d85t39OnTys9PV1Hjx51mpIjNxEREZo4caKWLFmihg0basOGDZo2bZoKFSqUp9oyMzOVnp6ugwcPasmSJfrmm2/UrFmz665LkgYNGiRvb2/H9QcffFDh4eGaP3++JGnt2rU6ePCg+vbt6/RcTEhIUExMjKPfxS6eBiZnWpisrCwtXbrUpe/gwYM1c+ZMffrpp26/0ZHzbRDpwjcNjh8/roYNG7r8j13vfQwAAJDfMHUJAAC4ZU2
"text/plain": [
"<Figure size 1800x500 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from locale import normalize\n",
"\n",
"\n",
"def analyze_balance(y_train, y_val, y_test, y_name):\n",
" print(\"Распределение классов в обучающей выборке:\")\n",
" print(y_train.value_counts(normalize=True))\n",
"\n",
" print(\"\\nРаспределение классов в контрольной выборке:\")\n",
" print(y_val.value_counts(normalize=True))\n",
"\n",
" print(\"\\nРаспределение классов в тестовой выборке:\")\n",
" print(y_test.value_counts(normalize=True))\n",
"\n",
" fig, axes = plt.subplots(1, 3, figsize=(18,5), sharey=True)\n",
" fig.suptitle('Распределение в различных выборках')\n",
"\n",
" sns.barplot(x=y_train.value_counts().index, y=y_train.value_counts(normalize=True), ax=axes[0])\n",
" axes[0].set_title('Обучающая выборка')\n",
" axes[0].set_xlabel(y_name)\n",
" axes[0].set_ylabel('Доля')\n",
"\n",
" sns.barplot(x=y_val.value_counts().index, y=y_val.value_counts(normalize=True), ax=axes[1])\n",
" axes[1].set_title('Контрольная выборка')\n",
" axes[1].set_xlabel(y_name)\n",
"\n",
" sns.barplot(x=y_test.value_counts().index, y=y_test.value_counts(normalize=True), ax=axes[2])\n",
" axes[2].set_title('Тестовая выборка')\n",
" axes[2].set_xlabel(y_name)\n",
"\n",
" plt.show()\n",
"\n",
"analyze_balance(y_train, y_val, y_test, 'stroke')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Заметим, что выборки не сбалансированы. Для балансировки будем использовать RandomOverSampler"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Распределение классов в обучающей выборке:\n",
"stroke\n",
"0 0.5\n",
"1 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в контрольной выборке:\n",
"stroke\n",
"0 0.5\n",
"1 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в тестовой выборке:\n",
"stroke\n",
"0 0.956947\n",
"1 0.043053\n",
"Name: proportion, dtype: float64\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABboAAAHyCAYAAAAtJXgGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABpm0lEQVR4nO3deZyNdf/H8feZMfsYYlZLM7aMncYSmmzDYGx1R5Y7Q4VChW4VlUHLJCVCWYoWuhNFd5E1SpoIKSKJkX0Y21hnmPP9/eEx5+c4M4wlZy69no/HPB5zvud7XdfnOmfO+V7nPdf5XjZjjBEAAAAAAAAAABbl4e4CAAAAAAAAAAC4HgTdAAAAAAAAAABLI+gGAAAAAAAAAFgaQTcAAAAAAAAAwNIIugEAAAAAAAAAlkbQDQAAAAAAAACwNIJuAAAAAAAAAIClEXQDAAAAAAAAACyNoBsAAAAAroHdbld6erp27Njh7lIAAAD+8Qi6AQAAACCfDhw4oAEDBigyMlLe3t4KCQlR5cqVlZGR4e7SAAAA/tEKubsAAACAG+39999Xz549Hbd9fHx0++23q0WLFnrhhRcUFhbmxuoAWNWff/6pJk2a6Ny5c3riiSd05513qlChQvLz81NAQIC7ywMAAPhHI+gGAAC3rJEjR6pMmTI6e/asvv/+e73zzjtasGCBNm3aJH9/f3eXB8Bi+vTpI29vb/34448qWbKku8sBAADARQi6AQDALatVq1aqXbu2JOmRRx5R8eLFNWbMGH3xxRfq0qWLm6sDYCXr1q3TN998o8WLFxNyAwAAFEDM0Q0AAP4xmjZtKklKTU2VJB05ckT/+c9/VK1aNQUGBiooKEitWrXSL7/84rLs2bNnNXz4cN1xxx3y9fVVRESE7rvvPm3fvl2StHPnTtlstjx/Gjdu7FjXihUrZLPZNGvWLA0dOlTh4eEKCAhQu3bttHv3bpdtr169Wi1btlSRIkXk7++vRo0aadWqVbnuY+PGjXPd/vDhw136zpgxQzExMfLz81OxYsXUuXPnXLd/uX27mN1u19ixY1WlShX5+voqLCxMffr00dGjR536RUVFqU2bNi7b6d+/v8s6c6t99OjRLo+pJGVmZiopKUnly5eXj4+PSpcuraefflqZmZm5PlYXu/RxCw4OVkJCgjZt2pSvZatWrap169apQYMG8vPzU5kyZTRp0iSnfllZWRo2bJhiYmJUpEgRBQQEKDY2VsuXL3fqt3XrVjVt2lTh4eGO/Xj00Ud15MgRl2336NHjis93jx49FBUV5bTc7t275efnJ5vNpp07d0r6/+f5/fffd+o7fPjwXJ+X/v37u9TTpk0bp23lrPP111/P49FzXf/06dNls9k0bdo0p36vvPKKbDabFixYkOe6pAt/XzmPg4eHh8LDw/XAAw9o165d11XXjz/+KF9fX23fvl1VqlSRj4+PwsPD1adPn1yfm9mzZzteX8HBwfr3v/+tvXv3OvXp0aOHAgMDtWPHDsXHxysgIEAlSpTQyJEjZYxxqffi5+bEiROKiYlRmTJltH//fkf766+/rgYNGqh48eLy8/NTTEyM5syZ47Td632MAQAACiLO6AYAAP8YOaF08eLFJUk7duzQvHnz1LFjR5UpU0ZpaWmaPHmyGjVqpM2bN6tEiRKSpOzsbLVp00bLli1T586d9eSTT+rEiRNasmSJNm3apHLlyjm20aVLF7Vu3dppu0OGDMm1npdfflk2m03PPPOMDh48qLFjxyouLk4bNmyQn5+fJOmbb75Rq1atFBMTo6SkJHl4eGj69Olq2rSpVq5cqbp167qst1SpUkpOTpYknTx5Uo899liu237hhRfUqVMnPfLIIzp06JDGjx+ve+65Rz///LOKFi3qskzv3r0VGxsrSfr88881d+5cp/v79OnjmB/9iSeeUGpqqiZMmKCff/5Zq1atkpeXV66Pw9U4duyYY98uZrfb1a5dO33//ffq3bu3KlWqpI0bN+rNN9/UH3/8oXnz5l1x3dHR0XruuedkjNH27ds1ZswYtW7d2ikgzcvRo0fVunVrderUSV26dNGnn36qxx57TN7e3nrooYckSRkZGXr33XfVpUsX9erVSydOnNB7772n+Ph4rVmzRjVr1pQknTp1SqVKlVLbtm0VFBSkTZs2aeLEidq7d6++/PJLl20HBwfrzTffdNx+8MEHr1jvsGHDdPbs2Sv2c4eePXvq888/16BBg9S8eXOVLl1aGzdu1IgRI/Twww+7vL5yExsbq969e8tut2vTpk0aO3as9u3bp5UrV15zXYcPH9bZs2f12GOPqWnTpnr00Ue1fft2TZw4UatXr9bq1avl4+Mj6f+vE1CnTh0lJycrLS1N48aN06pVq1xeX9nZ2WrZsqXuuusuvfbaa1q4cKGSkpJ0/vx5jRw5Mtdazp07p3/961/atWuXVq1apYiICMd948aNU7t27dStWzdlZWXpk08+UceOHfXVV18pISHhhj3GAAAABY4BAAC4xUyfPt1IMkuXLjWHDh0yu3fvNp988okpXry48fPzM3v27DHGGHP27FmTnZ3ttGxqaqrx8fExI0eOdLRNmzbNSDJjxoxx2ZbdbncsJ8mMHj3apU+VKlVMo0aNHLeXL19uJJmSJUuajIwMR/unn35qJJlx48Y51l2hQgUTHx/v2I4xxpw+fdqUKVPGNG/e3GVbDRo0MFWrVnXcPnTokJFkkpKSHG07d+40np6e5uWXX3ZaduPGjaZQoUIu7du2bTOSzAcffOBoS0pKMhcfSq5cudJIMjNnznRaduHChS7tkZGRJiEhwaX2fv36mUsPTy+t/emnnzahoaEmJibG6TH96KOPjIeHh1m5cqXT8pMmTTKSzKpVq1y2d7FGjRo5rc8YY4YOHWokmYMHD15xWUnmjTfecLRlZmaamjVrmtDQUJOVlWWMMeb8+fMmMzPTadmjR4+asLAw89BDD112G3379jWBgYEu7d26dTNlypRxarv0MUtMTDSRkZGO25s2bTIeHh6mVatWRpJJTU01xhjz119/GUlm2rRpTuu79LnO2Ua/fv1c6klISHDa1uVeF5db//79+02xYsVM8+bNTWZmpqlVq5a5/fbbzfHjx/NcT47IyEiTmJjo1Na1a1fj7+9/XXXl3G7WrJk5f/68oz3n/Wb8+PHGGGOysrJMaGioqVq1qjlz5oyj31dffWUkmWHDhjnaEhMTjSTz+OOPO9rsdrtJSEgw3t7e5tChQ071Tp8+3djtdtOtWzfj7+9vVq9e7VL36dOnnW5nZWWZqlWrmqZNmzq1X89jDAAAUBAxdQkAALhlxcXFKSQkRKVLl1bnzp0VGBiouXPnOubX9fHxkYfHhcOh7OxsHT58WIGBgapYsaLWr1/vWM9nn32m4OBgPf744y7buHRKh6vRvXt3FS5c2HH7/vvvV0REhGPagA0bNmjbtm3q2rWrDh8+rPT0dKWnp+vUqVNq1qyZvvvuO9ntdqd1nj17Vr6+vpfd7ueffy673a5OnTo51pmenq7w8HBVqFDBZSqNrKwsSXKcrZqb2bNnq0iRImrevLnTOmNiYhQYGOiyznPnzjn1S09Pv+IZxnv37tX48eP1wgsvKDAw0GX7lSpVUnR0tNM6c6aruXT7ucmp6dChQ0pJSdHcuXNVvXp1BQcHX3HZQoUKqU+fPo7b3t7e6tOnjw4ePKh169ZJkjw9PeXt7S3pwhnoR44c0fnz51W7dm2nv7ccx48fV1pampYtW6b58+frnnvucemTlZV12eclN0OGDNGdd96pjh07OrWHhIRIkvbs2ZOv9Zw9e9blOTx37lyufU+fPq309HQdPXrUaUqOvISHh2vixIlasmSJYmNjtWHDBk2bNk1BQUH5qi0zM1Pp6ek6ePCglixZom+++UbNmjW77rokadCgQfL09HTcfvDBBxUWFqb58+dLktauXauDBw+qb9++Tq/FhIQERUdHO/pd7OJpYHKmhcnKytLSpUtd+g4ePFgzZ87Up59+mus3OnK+DSJd+KbB8ePHFRsb6/I3dr2PMQAAQEHD1CUAAOCWNXH
"text/plain": [
"<Figure size 1800x500 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"randoversamp = RandomOverSampler(random_state=42)\n",
"\n",
"# Применение RandomOverSampler для балансировки выборок\n",
"X_train_resampled, y_train_resampled = randoversamp.fit_resample(X_train, y_train)\n",
"X_val_resampled, y_val_resampled = randoversamp.fit_resample(X_val, y_val)\n",
"\n",
"# Проверка сбалансированности после RandomOverSampler\n",
"analyze_balance(y_train_resampled, y_val_resampled, y_test, \"stroke\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Выборки сбалансированы"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Применим унитарное кодирование категориальных признаков (one-hot encoding), переведя их в бинарные вектора."
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" age hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 67.0 0 0 190.70 36.0 True \n",
"1 72.0 0 0 99.73 36.7 True \n",
"2 74.0 1 1 70.09 27.4 True \n",
"3 42.0 0 0 59.43 25.4 False \n",
"4 60.0 0 0 69.53 26.2 True \n",
"\n",
" ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 True False True \n",
"1 True False False \n",
"2 True False True \n",
"3 True False False \n",
"4 True False False \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 False False True \n",
"1 True False False \n",
"2 False False False \n",
"3 False False True \n",
"4 True False True \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 True False \n",
"1 True False \n",
"2 False True \n",
"3 False True \n",
"4 False True \n",
"\n",
" smoking_status_smokes \n",
"0 False \n",
"1 False \n",
"2 False \n",
"3 False \n",
"4 False \n"
]
}
],
"source": [
"# Определение категориальных признаков\n",
"categorical_features = [\n",
" \"gender\",\n",
" \"ever_married\",\n",
" \"work_type\",\n",
" \"Residence_type\",\n",
" \"smoking_status\",\n",
"]\n",
"\n",
"# Применение one-hot encoding к обучающей выборке\n",
"X_train_encoded = pd.get_dummies(\n",
" X_train_resampled, columns=categorical_features, drop_first=True\n",
")\n",
"\n",
"# Применение one-hot encoding к контрольной выборке\n",
"X_val_encoded = pd.get_dummies(\n",
" X_val_resampled, columns=categorical_features, drop_first=True\n",
")\n",
"\n",
"# Применение one-hot encoding к тестовой выборке\n",
"X_test_encoded = pd.get_dummies(X_test, columns=categorical_features, drop_first=True)\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Перейдем к числовым признакам, а именно к колонке age, применим дискретизацию (позволяет преобразовать данные из числового представления в категориальное):"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 0 0 190.70 36.0 True \n",
"1 0 0 99.73 36.7 True \n",
"2 1 1 70.09 27.4 True \n",
"3 0 0 59.43 25.4 False \n",
"4 0 0 69.53 26.2 True \n",
"\n",
" ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 True False True \n",
"1 True False False \n",
"2 True False True \n",
"3 True False False \n",
"4 True False False \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 False False True \n",
"1 True False False \n",
"2 False False False \n",
"3 False False True \n",
"4 True False True \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 True False \n",
"1 True False \n",
"2 False True \n",
"3 False True \n",
"4 False True \n",
"\n",
" smoking_status_smokes age_bin \n",
"0 False old \n",
"1 False old \n",
"2 False old \n",
"3 False middle-aged \n",
"4 False old \n"
]
}
],
"source": [
"# Определение числовых признаков для дискретизации\n",
"numerical_features = [\"age\"]\n",
"\n",
"\n",
"# Функция для дискретизации числовых признаков\n",
"def discretize_features(df, features, bins, labels):\n",
" for feature in features:\n",
" df[f\"{feature}_bin\"] = pd.cut(df[feature], bins=bins, labels=labels)\n",
" df.drop(columns=[feature], inplace=True)\n",
" return df\n",
"\n",
"\n",
"# Заданные интервалы и метки\n",
"age_bins = [0, 25, 55, 100]\n",
"age_labels = [\"young\", \"middle-aged\", \"old\"]\n",
"\n",
"# Применение дискретизации к обучающей, контрольной и тестовой выборкам\n",
"X_train_encoded = discretize_features(\n",
" X_train_encoded, numerical_features, bins=age_bins, labels=age_labels\n",
")\n",
"X_val_encoded = discretize_features(\n",
" X_val_encoded, numerical_features, bins=age_bins, labels=age_labels\n",
")\n",
"X_test_encoded = discretize_features(\n",
" X_test_encoded, numerical_features, bins=age_bins, labels=age_labels\n",
")\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Применим ручной синтез признаков. Например, в этом случае создадим признак, в котором вычисляется отклонение уровня глюкозы от среднего для определенной возрастной группы. Вышеуказанный признак может быть полезен для определения пациентов с аномальными данными."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 0 0 190.70 36.0 True \n",
"1 0 0 99.73 36.7 True \n",
"2 1 1 70.09 27.4 True \n",
"3 0 0 59.43 25.4 False \n",
"4 0 0 69.53 26.2 True \n",
"\n",
" ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 True False True \n",
"1 True False False \n",
"2 True False True \n",
"3 True False False \n",
"4 True False False \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 False False True \n",
"1 True False False \n",
"2 False False False \n",
"3 False False True \n",
"4 True False True \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 True False \n",
"1 True False \n",
"2 False True \n",
"3 False True \n",
"4 False True \n",
"\n",
" smoking_status_smokes age_bin glucose_age_deviation \n",
"0 False old 58.945260 \n",
"1 False old -32.024740 \n",
"2 False old -61.664740 \n",
"3 False middle-aged -46.635693 \n",
"4 False old -62.224740 \n"
]
}
],
"source": [
"age_glucose_mean = X_train_encoded.groupby(\"age_bin\", observed=False)[\n",
" \"avg_glucose_level\"\n",
"].transform(\"mean\")\n",
"X_train_encoded[\"glucose_age_deviation\"] = (\n",
" X_train_encoded[\"avg_glucose_level\"] - age_glucose_mean\n",
")\n",
"\n",
"age_glucose_mean = X_val_encoded.groupby(\"age_bin\", observed=False)[\n",
" \"avg_glucose_level\"\n",
"].transform(\"mean\")\n",
"X_val_encoded[\"glucose_age_deviation\"] = (\n",
" X_val_encoded[\"avg_glucose_level\"] - age_glucose_mean\n",
")\n",
"\n",
"age_glucose_mean = X_test_encoded.groupby(\"age_bin\", observed=False)[\n",
" \"avg_glucose_level\"\n",
"].transform(\"mean\")\n",
"X_test_encoded[\"glucose_age_deviation\"] = (\n",
" X_test_encoded[\"avg_glucose_level\"] - age_glucose_mean\n",
")\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Используем масштабирование признаков, для приведения всех числовых признаков к одинаковым или очень похожим диапазонам значений/распределениям. \n",
"### Масштабирование признаков позволяет получить более качественную модель за счет снижения доминирования одних признаков над другими."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 0 0 1.274173 0.891831 True \n",
"1 0 0 -0.349199 0.991754 True \n",
"2 1 1 -0.878129 -0.335787 True \n",
"3 0 0 -1.068358 -0.621280 False \n",
"4 0 0 -0.888122 -0.507083 True \n",
"\n",
" ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 True False True \n",
"1 True False False \n",
"2 True False True \n",
"3 True False False \n",
"4 True False False \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 False False True \n",
"1 True False False \n",
"2 False False False \n",
"3 False False True \n",
"4 True False True \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 True False \n",
"1 True False \n",
"2 False True \n",
"3 False True \n",
"4 False True \n",
"\n",
" smoking_status_smokes age_bin glucose_age_deviation \n",
"0 False old 1.092637 \n",
"1 False old -0.593625 \n",
"2 False old -1.143046 \n",
"3 False middle-aged -0.864461 \n",
"4 False old -1.153427 \n"
]
}
],
"source": [
"numerical_features = [\"avg_glucose_level\", \"bmi\", \"glucose_age_deviation\"]\n",
"\n",
"scaler = StandardScaler()\n",
"X_train_encoded[numerical_features] = scaler.fit_transform(\n",
" X_train_encoded[numerical_features]\n",
")\n",
"X_val_encoded[numerical_features] = scaler.transform(X_val_encoded[numerical_features])\n",
"X_test_encoded[numerical_features] = scaler.transform(\n",
" X_test_encoded[numerical_features]\n",
")\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сконструируем признаки, используя фреймворк Featuretools:"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"index \n",
"0 0 0 1.274173 0.891831 True \n",
"1 0 0 -0.349199 0.991754 True \n",
"2 1 1 -0.878129 -0.335787 True \n",
"3 0 0 -1.068358 -0.621280 False \n",
"4 0 0 -0.888122 -0.507083 True \n",
"\n",
" ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"index \n",
"0 True False True \n",
"1 True False False \n",
"2 True False True \n",
"3 True False False \n",
"4 True False False \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"index \n",
"0 False False True \n",
"1 True False False \n",
"2 False False False \n",
"3 False False True \n",
"4 True False True \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"index \n",
"0 True False \n",
"1 True False \n",
"2 False True \n",
"3 False True \n",
"4 False True \n",
"\n",
" smoking_status_smokes age_bin glucose_age_deviation \n",
"index \n",
"0 False old 1.092637 \n",
"1 False old -0.593625 \n",
"2 False old -1.143046 \n",
"3 False middle-aged -0.864461 \n",
"4 False old -1.153427 \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\Users\\Leo\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\mai-S9i2J6c7-py3.12\\Lib\\site-packages\\woodwork\\type_sys\\utils.py:33: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
" pd.to_datetime(\n"
]
}
],
"source": [
"data = X_train_encoded.copy() # Используем предобработанные данные\n",
"\n",
"es = ft.EntitySet(id=\"patients\")\n",
"\n",
"es = es.add_dataframe(\n",
" dataframe_name=\"strokes_data\", dataframe=data, index=\"index\", make_index=True\n",
")\n",
"\n",
"feature_matrix, feature_defs = ft.dfs(\n",
" entityset=es, target_dataframe_name=\"strokes_data\", max_depth=1\n",
")\n",
"\n",
"print(feature_matrix.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Оценим качество набора признаков.\n",
"\n",
"1. Предсказательная способность (для задачи классификации)\n",
" - Метрики: Accuracy, Precision, Recall, F1-Score, ROC AUC\n",
" - Методы: Обучение модели на обучающей выборке и оценка на валидационной и тестовой выборках.\n",
"\n",
"2. Вычислительная эффективность\n",
" - Методы: Измерение времени, затраченного на генерацию признаков и обучение модели.\n",
"\n",
"3. Надежность\n",
" - Методы: Кросс-валидация и анализ чувствительности модели к изменениям в данных.\n",
"\n",
"4. Корреляция\n",
" - Методы: Анализ корреляционной матрицы признаков и исключение мультиколлинеарных признаков.\n",
"\n",
"5. Логическая согласованность\n",
" - Методы: Проверка логической связи признаков с целевой переменной и интерпретация результатов модели."
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Время обучения модели: 0.55 секунд\n"
]
}
],
"source": [
"X_train_encoded = pd.get_dummies(X_train_encoded, drop_first=True)\n",
"X_val_encoded = pd.get_dummies(X_val_encoded, drop_first=True)\n",
"X_test_encoded = pd.get_dummies(X_test_encoded, drop_first=True)\n",
"\n",
"all_columns = X_train_encoded.columns\n",
"X_train_encoded = X_train_encoded.reindex(columns=all_columns, fill_value=0)\n",
"X_val_encoded = X_val_encoded.reindex(columns=all_columns, fill_value=0)\n",
"X_test_encoded = X_test_encoded.reindex(columns=all_columns, fill_value=0)\n",
"\n",
"# Выбор модели\n",
"model = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"\n",
"# Начинаем отсчет времени\n",
"start_time = time.time()\n",
"model.fit(X_train_encoded, y_train_resampled)\n",
"\n",
"# Время обучения модели\n",
"train_time = time.time() - start_time\n",
"\n",
"print(f\"Время обучения модели: {train_time:.2f} секунд\")"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Feature Importance:\n",
" feature importance\n",
"16 age_bin_old 0.199473\n",
"3 bmi 0.186638\n",
"14 glucose_age_deviation 0.172977\n",
"2 avg_glucose_level 0.171885\n",
"15 age_bin_middle-aged 0.037563\n",
"5 ever_married_Yes 0.036580\n",
"4 gender_Male 0.028419\n",
"0 hypertension 0.025984\n",
"8 work_type_Self-employed 0.022772\n",
"10 Residence_type_Urban 0.022221\n",
"1 heart_disease 0.020967\n",
"12 smoking_status_never smoked 0.017890\n",
"7 work_type_Private 0.016844\n",
"11 smoking_status_formerly smoked 0.015641\n",
"13 smoking_status_smokes 0.012366\n",
"9 work_type_children 0.011684\n",
"6 work_type_Never_worked 0.000096\n"
]
}
],
"source": [
"# Получение важности признаков\n",
"importances = model.feature_importances_\n",
"feature_names = X_train_encoded.columns\n",
"\n",
"# Сортировка признаков по важности\n",
"feature_importance = pd.DataFrame({\"feature\": feature_names, \"importance\": importances})\n",
"feature_importance = feature_importance.sort_values(by=\"importance\", ascending=False)\n",
"\n",
"print(\"Feature Importance:\")\n",
"print(feature_importance)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.9465101108936725\n",
"Precision: 0.21428571428571427\n",
"Recall: 0.09090909090909091\n",
"F1 Score: 0.1276595744680851\n",
"ROC AUC: 0.5379562496126913\n",
"Cross-validated Accuracy: 0.988641983507665\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABCEAAAIjCAYAAAA9agHPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADYqUlEQVR4nOzdeVRV1f//8ecFFJBRFBUUBRMR53meUks0zaE0jUJUNPuEY6ZZOeCcs2ZqaYKWQ2VmfnJKTZxTnNCUHFDEkqJPKogmCtzfH/48326gAiqkvR5rnbW4e++zz/uci63Omz2YzGazGRERERERERGRR8wqvwMQERERERERkX8HJSFEREREREREJE8oCSEiIiIiIiIieUJJCBERERERERHJE0pCiIiIiIiIiEieUBJCRERERERERPKEkhAiIiIiIiIikieUhBARERERERGRPKEkhIiIiIiIiIjkCSUhRERERERERCRPKAkhIiIi/2oRERGYTKYsj7fffvuRXHPPnj2MGTOGK1euPJL+H8Sd53HgwIH8DiXX5s2bR0RERH6HISIiWbDJ7wBERERE/gnGjh2Lj4+PRVnlypUfybX27NlDWFgYwcHBuLq6PpJr/JvNmzePokWLEhwcnN+hiIjI3ygJISIiIgK0adOG2rVr53cYD+TatWs4ODjkdxj55vr16xQqVCi/wxARkXvQdAwRERGRbNiwYQNNmjTBwcEBJycnnnvuOY4fP27R5ujRowQHB1O2bFns7OwoUaIEvXr14o8//jDajBkzhrfeegsAHx8fY+pHXFwccXFxmEymLKcSmEwmxowZY9GPyWTixIkTvPzyyxQuXJjGjRsb9Z999hm1atXC3t4eNzc3unXrxoULF3J178HBwTg6OhIfH0+7du1wdHSkZMmSfPjhhwAcO3aMFi1a4ODgQJkyZVi+fLnF+XemeOzYsYPXXnuNIkWK4OzsTFBQEJcvX850vXnz5lGpUiVsbW3x9PTkjTfeyDR1pXnz5lSuXJmDBw/StGlTChUqxDvvvIO3tzfHjx9n+/btxrNt3rw5AJcuXWLo0KFUqVIFR0dHnJ2dadOmDdHR0RZ9R0ZGYjKZ+OKLL5gwYQKlSpXCzs6Oli1bcubMmUzx7tu3j7Zt21K4cGEcHByoWrUqs2fPtmjz008/8eKLL+Lm5oadnR21a9dm7dq1Of0qREQeexoJISIiIgIkJSXxv//9z6KsaNGiAHz66af06NGD1q1b8/7773P9+nXmz59P48aNOXz4MN7e3gBs3ryZs2fP0rNnT0qUKMHx48f5+OOPOX78OD/88AMmk4nOnTtz6tQpVqxYwcyZM41ruLu78/vvv+c47i5duuDr68vEiRMxm80ATJgwgZEjR9K1a1dCQkL4/fff+eCDD2jatCmHDx/O1RSQ9PR02rRpQ9OmTZkyZQrLli0jNDQUBwcH3n33XQIDA+ncuTMLFiwgKCiIBg0aZJreEhoaiqurK2PGjOHkyZPMnz+f8+fPGy/9cDu5EhYWRqtWrXj99deNdlFRUezevZsCBQoY/f3xxx+0adOGbt268corr1C8eHGaN29O//79cXR05N133wWgePHiAJw9e5Y1a9bQpUsXfHx8+O233/joo49o1qwZJ06cwNPT0yLeyZMnY2VlxdChQ0lKSmLKlCkEBgayb98+o83mzZtp164dHh4eDBw4kBIlShATE8O3337LwIEDATh+/DiNGjWiZMmSvP322zg4OPDFF1/QsWNHvvrqKzp16pTj70NE5LFlFhEREfkXCw8PNwNZHmaz2Xz16lWzq6uruU+fPhbn/frrr2YXFxeL8uvXr2fqf8WKFWbAvGPHDqNs6tSpZsB87tw5i7bnzp0zA+bw8PBM/QDm0aNHG59Hjx5tBszdu3e3aBcXF2e2trY2T5gwwaL82LFjZhsbm0zld3seUVFRRlmPHj3MgHnixIlG2eXLl8329vZmk8lkXrlypVH+008/ZYr1Tp+1atUy37x50yifMmWKGTB/8803ZrPZbE5MTDQXLFjQ/Oyzz5rT09ONdnPnzjUD5sWLFxtlzZo1MwPmBQsWZLqHSpUqmZs1a5ap/MaNGxb9ms23n7mtra157NixRtm2bdvMgNnf39+cmppqlM+ePdsMmI8dO2Y2m83mtLQ0s4+Pj7lMmTLmy5cvW/SbkZFh/NyyZUtzlSpVzDdu3LCob9iwodnX1zdTnCIiTzJNxxAREREBPvzwQzZv3mxxwO2/dF+5coXu3bvzv//9zzisra2pV68e27ZtM/qwt7c3fr5x4wb/+9//qF+/PgCHDh16JHH369fP4vPq1avJyMiga9euFvGWKFECX19fi3hzKiQkxPjZ1dUVPz8/HBwc6Nq1q1Hu5+eHq6srZ8+ezXR+3759LUYyvP7669jY2LB+/XoAtmzZws2bNxk0aBBWVv/3v6l9+vTB2dmZdevWWfRna2tLz549sx2/ra2t0W96ejp//PEHjo6O+Pn5Zfn99OzZk4IFCxqfmzRpAmDc2+HDhzl37hyDBg3KNLrkzsiOS5cu8f3339O1a1euXr1qfB9//PEHrVu35vTp0/zyyy/ZvgcRkcedpmOIiIiIAHXr1s1yYcrTp08D0KJFiyzPc3Z2Nn6+dOkSYWFhrFy5ksTERIt2SUlJDzHa//P3KQ+nT5/GbDbj6+ubZfu/JgFyws7ODnd3d4syFxcXSpUqZbxw/7U8q7Ue/h6To6MjHh4exMXFAXD+/HngdiLjrwoWLEjZsmWN+jtKlixpkSS4n4yMDGbPns28efM4d+4c6enpRl2RIkUytS9durTF58KFCwMY9xYbGwvcexeVM2fOYDabGTlyJCNHjsyyTWJiIiVLlsz2fYiIPM6UhBARERG5h4yMDOD2uhAlSpTIVG9j83//O9W1a1f27NnDW2+9RfXq1XF0dCQjI4OAgACjn3v5+8v8HX99Wf67v46+uBOvyWRiw4YNWFtbZ2rv6Oh43ziyklVf9yo3///1KR6lv9/7/UycOJGRI0fSq1cvxo0bh5ubG1ZWVgwaNCjL7+dh3NudfocOHUrr1q2zbFOuXLls9yci8rhTEkJERETkHp566ikAihUrRqtWre7a7vLly2zdupWwsDBGjRpllN8ZSfFXd0s23PlL+993gvj7CID7xWs2m/Hx8aF8+fLZPi8vnD59mqefftr4nJKSQkJCAm3btgWgTJkyAJw8eZKyZcsa7W7evMm5c+fu+fz/6m7Pd9WqVTz99NN88sknFuVXrlwxFgjNiTu/Gz/++ONdY7tzHwUKFMh2/CIiTzKtCSEiIiJyD61bt8bZ2ZmJEydy69atTPV3drS481fzv/+VfNasWZnOcXBwADInG5ydnSlatCg7duywKJ83b1624+3cuTPW1taEhYVlisVsNltsF5rXPv74Y4tnOH/+fNLS0mjTpg0ArVq1omDBgsyZM8ci9k8++YSkpCSee+65bF3HwcEh07OF29/R35/Jl19+mes1GWrWrImPjw+zZs3KdL071ylWrBjNmzfno48+IiEhIVMfudkRRUTkcaaRECIiIiL34OzszPz583n11VepWbMm3bp1w93dnfj4eNatW0ejRo2YO3cuzs7OxvaVt27domTJknz33XecO3cuU5+1atUC4N1336Vbt24UKFCA9u3b4+DgQEhICJMnTyYkJITatWuzY8cOTp06le14n3rqKcaPH8+IESOIi4ujY8eOODk5ce7cOb7++mv69u3L0KFDH9rzyYmbN2/SsmVLunbtysmTJ5k3bx6NGzfm+eefB25vUzpixAjCwsIICAjg+eefN9rVqVOHV155JVvXqVWrFvPnz2f8+PGUK1eOYsWK0aJFC9q1a8fYsWPp2bMnDRs25NixYyxbtsxi1EVOWFlZMX/+fNq3b0/16tXp2bMnHh4e/PTTTxw/fpxNmzYBtxc9bdy4MVWqVKFPnz6ULVuW3377jb179/Lzzz8THR2dq+uLiDyOlIQQERERuY+XX34ZT09PJk+ezNSpU0lNTaVkyZI0adLEYneG5cuX079/fz788EPMZjPPPvssGzZswNPT06K
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Accuracy: 1.0\n",
"Train Precision: 1.0\n",
"Train Recall: 1.0\n",
"Train F1 Score: 1.0\n",
"Train ROC AUC: 1.0\n"
]
}
],
"source": [
"# Предсказание и оценка\n",
"y_pred = model.predict(X_test_encoded)\n",
"\n",
"accuracy = accuracy_score(y_test, y_pred)\n",
"precision = precision_score(y_test, y_pred)\n",
"recall = recall_score(y_test, y_pred)\n",
"f1 = f1_score(y_test, y_pred)\n",
"roc_auc = roc_auc_score(y_test, y_pred)\n",
"\n",
"print(f\"Accuracy: {accuracy}\")\n",
"print(f\"Precision: {precision}\")\n",
"print(f\"Recall: {recall}\")\n",
"print(f\"F1 Score: {f1}\")\n",
"print(f\"ROC AUC: {roc_auc}\")\n",
"\n",
"# Кросс-валидация\n",
"scores = cross_val_score(\n",
" model, X_train_encoded, y_train_resampled, cv=5, scoring=\"accuracy\"\n",
")\n",
"accuracy_cv = scores.mean()\n",
"print(f\"Cross-validated Accuracy: {accuracy_cv}\")\n",
"\n",
"# Анализ важности признаков\n",
"feature_importances = model.feature_importances_\n",
"feature_names = X_train_encoded.columns\n",
"\n",
"importance_df = pd.DataFrame(\n",
" {\"Feature\": feature_names, \"Importance\": feature_importances}\n",
")\n",
"importance_df = importance_df.sort_values(by=\"Importance\", ascending=False)\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"sns.barplot(x=\"Importance\", y=\"Feature\", data=importance_df)\n",
"plt.title(\"Feature Importance\")\n",
"plt.show()\n",
"\n",
"# Проверка на переобучение\n",
"y_train_pred = model.predict(X_train_encoded)\n",
"\n",
"accuracy_train = accuracy_score(y_train_resampled, y_train_pred)\n",
"precision_train = precision_score(y_train_resampled, y_train_pred)\n",
"recall_train = recall_score(y_train_resampled, y_train_pred)\n",
"f1_train = f1_score(y_train_resampled, y_train_pred)\n",
"roc_auc_train = roc_auc_score(y_train_resampled, y_train_pred)\n",
"\n",
"print(f\"Train Accuracy: {accuracy_train}\")\n",
"print(f\"Train Precision: {precision_train}\")\n",
"print(f\"Train Recall: {recall_train}\")\n",
"print(f\"Train F1 Score: {f1_train}\")\n",
"print(f\"Train ROC AUC: {roc_auc_train}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}