1114 lines
191 KiB
Plaintext
Raw Permalink Normal View History

2024-12-19 17:02:15 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Вариант 4. Данные по инсультам"
]
},
{
"cell_type": "code",
"execution_count": 163,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['id', 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married',\n",
" 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi',\n",
" 'smoking_status', 'stroke'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9046</td>\n",
" <td>Male</td>\n",
" <td>67.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>228.69</td>\n",
" <td>36.6</td>\n",
" <td>formerly smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>51676</td>\n",
" <td>Female</td>\n",
" <td>61.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>202.21</td>\n",
" <td>NaN</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>31112</td>\n",
" <td>Male</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>105.92</td>\n",
" <td>32.5</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>60182</td>\n",
" <td>Female</td>\n",
" <td>49.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>171.23</td>\n",
" <td>34.4</td>\n",
" <td>smokes</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1665</td>\n",
" <td>Female</td>\n",
" <td>79.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>174.12</td>\n",
" <td>24.0</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id gender age hypertension heart_disease ever_married \\\n",
"0 9046 Male 67.0 0 1 Yes \n",
"1 51676 Female 61.0 0 0 Yes \n",
"2 31112 Male 80.0 0 1 Yes \n",
"3 60182 Female 49.0 0 0 Yes \n",
"4 1665 Female 79.0 1 0 Yes \n",
"\n",
" work_type Residence_type avg_glucose_level bmi smoking_status \\\n",
"0 Private Urban 228.69 36.6 formerly smoked \n",
"1 Self-employed Rural 202.21 NaN never smoked \n",
"2 Private Rural 105.92 32.5 never smoked \n",
"3 Private Urban 171.23 34.4 smokes \n",
"4 Self-employed Rural 174.12 24.0 never smoked \n",
"\n",
" stroke \n",
"0 1 \n",
"1 1 \n",
"2 1 \n",
"3 1 \n",
"4 1 "
]
},
"execution_count": 163,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from sklearn.preprocessing import StandardScaler\n",
"import featuretools as ft\n",
"import time\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"df = pd.read_csv(\"./data/healthcare-dataset-stroke-data.csv\")\n",
"\n",
"print(df.columns)\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Бизнес цели и цели технического проекта.\n",
"## Бизнес цели:\n",
"### 1. Предсказание инсульта: Разработать систему, которая сможет предсказать вероятность инсульта у пациентов на основе их медицинских и социальных данных. Это может помочь медицинским учреждениям и специалистам в более раннем выявлении пациентов с высоким риском.\n",
"### 2. Снижение затрат на лечение: Предупреждение инсультов у пациентов позволит снизить затраты на лечение и реабилитацию. Это также поможет улучшить качество медицинских услуг и повысить удовлетворенность пациентов.\n",
"### 3. Повышение эффективности профилактики: Выявление факторов риска инсульта на ранней стадии может способствовать более эффективному проведению профилактических мероприятий.\n",
"## Цели технического проекта:\n",
"### 1. Создание и обучение модели машинного обучения: Разработка модели, способной предсказать вероятность инсульта на основе данных о пациентах (например, возраст, уровень глюкозы, наличие сердечно-сосудистых заболеваний, тип работы, индекс массы тела и т.д.).\n",
"### 2. Анализ и обработка данных: Провести предобработку данных (очистка, заполнение пропущенных значений, кодирование категориальных признаков), чтобы улучшить качество и надежность модели.\n",
"### 3. Оценка модели: Использовать метрики, такие как точность, полнота и F1-мера, чтобы оценить эффективность модели и минимизировать риск ложных положительных и ложных отрицательных предсказаний."
]
},
{
"cell_type": "code",
"execution_count": 164,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"id 0\n",
"gender 0\n",
"age 0\n",
"hypertension 0\n",
"heart_disease 0\n",
"ever_married 0\n",
"work_type 0\n",
"Residence_type 0\n",
"avg_glucose_level 0\n",
"bmi 201\n",
"smoking_status 0\n",
"stroke 0\n",
"dtype: int64\n",
"\n",
"id False\n",
"gender False\n",
"age False\n",
"hypertension False\n",
"heart_disease False\n",
"ever_married False\n",
"work_type False\n",
"Residence_type False\n",
"avg_glucose_level False\n",
"bmi True\n",
"smoking_status False\n",
"stroke False\n",
"dtype: bool\n",
"\n",
"bmi процент пустых значений: %3.93\n"
]
}
],
"source": [
"print(df.isnull().sum())\n",
"print()\n",
"\n",
"print(df.isnull().any())\n",
"print()\n",
"\n",
"for i in df.columns:\n",
" null_rate = df[i].isnull().sum() / len(df) * 100\n",
" if null_rate > 0:\n",
" print(f\"{i} процент пустых значений: %{null_rate:.2f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Видим пустые значения в bmi, заменяем их"
]
},
{
"cell_type": "code",
"execution_count": 165,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Количество пустых значений в каждом столбце после замены:\n",
"id 0\n",
"gender 0\n",
"age 0\n",
"hypertension 0\n",
"heart_disease 0\n",
"ever_married 0\n",
"work_type 0\n",
"Residence_type 0\n",
"avg_glucose_level 0\n",
"bmi 0\n",
"smoking_status 0\n",
"stroke 0\n",
"dtype: int64\n"
]
}
],
"source": [
"df[\"bmi\"] = df[\"bmi\"].fillna(df[\"bmi\"].median())\n",
"\n",
"missing_values = df.isnull().sum()\n",
"\n",
"print(\"Количество пустых значений в каждом столбце после замены:\")\n",
"print(missing_values)"
]
},
{
"cell_type": "code",
"execution_count": 166,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',\n",
" 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi',\n",
" 'smoking_status', 'stroke'],\n",
" dtype='object')\n"
]
}
],
"source": [
"df = df.drop('id', axis = 1)\n",
"print(df.columns)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Создаем выборки"
]
},
{
"cell_type": "code",
"execution_count": 167,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Размер обучающей выборки: (2503, 10)\n",
"Размер контрольной выборки: (1074, 10)\n",
"Размер тестовой выборки: (1533, 10)\n"
]
}
],
"source": [
"# Разделим данные на признак (X) и переменую (Y)\n",
"# Начнем со stroke\n",
"X = df.drop(columns=['stroke'])\n",
"y = df['stroke']\n",
"\n",
"# Разбиваем на обучающую и тестовую выборки\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n",
"\n",
"# Разбиваем на обучающую и контрольную выборки\n",
"X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.3)\n",
"\n",
"print(\"Размер обучающей выборки: \", X_train.shape)\n",
"print(\"Размер контрольной выборки: \", X_val.shape)\n",
"print(\"Размер тестовой выборки: \", X_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Оценим сбалансированность сборок"
]
},
{
"cell_type": "code",
"execution_count": 168,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Распределение классов в обучающей выборке:\n",
"stroke\n",
"0 0.948861\n",
"1 0.051139\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в контрольной выборке:\n",
"stroke\n",
"0 0.947858\n",
"1 0.052142\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в тестовой выборке:\n",
"stroke\n",
"0 0.957599\n",
"1 0.042401\n",
"Name: proportion, dtype: float64\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABboAAAHyCAYAAAAtJXgGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABpkUlEQVR4nO3dd3RU1d7G8WcS0kMRUikmNAkdDEXASAsECE2vIOVKAAUUsIAXFVQCWCKiCAIKqGABrwgKXkWpgiJGEBAFKSIE6YHQQk0gs98/WJmXYSYQikyOfj9rZa3Mnn3O+Z2ZZPaZZ87sYzPGGAEAAAAAAAAAYFFeni4AAAAAAAAAAIDrQdANAAAAAAAAALA0gm4AAAAAAAAAgKURdAMAAAAAAAAALI2gGwAAAAAAAABgaQTdAAAAAAAAAABLI+gGAAAAAAAAAFgaQTcAAAAAAAAAwNIIugEAAADgGtjtdmVkZGjHjh2eLgUAAOAfj6AbAAAAAPLpwIEDevzxxxUVFSVfX1+FhoaqSpUqyszM9HRpAAAA/2iFPF0AAADAjfbee++pV69ejtt+fn669dZb1bJlSz333HMKDw/3YHUArOqPP/5Q06ZNde7cOT366KO6/fbbVahQIQUEBCgoKMjT5QEAAPyjEXQDAIC/rVGjRqls2bI6e/asvv/+e7311lv66quvtHHjRgUGBnq6PAAW069fP/n6+urHH39UqVKlPF0OAAAALkLQDQAA/rZat26tOnXqSJIefPBBlShRQmPHjtXnn3+url27erg6AFaydu1affPNN1q0aBEhNwAAQAHEHN0AAOAfo1mzZpKktLQ0SdKRI0f0n//8R9WrV1dwcLCKFCmi1q1b65dffnFZ9uzZsxoxYoRuu+02+fv7KzIyUvfcc4+2b98uSdq5c6dsNlueP02aNHGsa/ny5bLZbJo1a5aGDRumiIgIBQUFqX379tq9e7fLtletWqVWrVqpaNGiCgwMVOPGjbVy5Uq3+9ikSRO32x8xYoRL3xkzZig2NlYBAQEqXry4unTp4nb7l9u3i9ntdo0bN05Vq1aVv7+/wsPD1a9fPx09etSpX3R0tNq2beuynYEDB7qs013tY8aMcXlMJSkrK0vJycmqUKGC/Pz8VKZMGT355JPKyspy+1hd7NLHLSQkRImJidq4cWO+lq1WrZrWrl2rhg0bKiAgQGXLltXkyZOd+mVnZ2v48OGKjY1V0aJFFRQUpLi4OC1btsyp39atW9WsWTNFREQ49uOhhx7SkSNHXLbds2fPKz7fPXv2VHR0tNNyu3fvVkBAgGw2m3bu3Cnp/5/n9957z6nviBEj3D4vAwcOdKmnbdu2TtvKXeerr76ax6Pnuv7p06fLZrNp2rRpTv1eeukl2Ww2ffXVV3muS7rw95X7OHh5eSkiIkL33Xefdu3adV11/fjjj/L399f27dtVtWpV+fn5KSIiQv369XP73MyePdvx/xUSEqJ///vf2rt3r1Ofnj17Kjg4WDt27FBCQoKCgoJUsmRJjRo1SsYYl3ovfm5OnDih2NhYlS1bVvv373e0v/rqq2rYsKFKlCihgIAAxcbGas6cOU7bvd7HGAAAoCDijG4AAPCPkRtKlyhRQpK0Y8cOzZs3T506dVLZsmWVnp6uKVOmqHHjxtq0aZNKliwpScrJyVHbtm21dOlSdenSRY899phOnDihxYsXa+PGjSpfvrxjG127dlWbNm2ctjt06FC39bz44ouy2Wx66qmndPDgQY0bN07x8fFav369AgICJEnffPONWrdurdjYWCUnJ8vLy0vTp09Xs2bNtGLFCtWrV89lvaVLl1ZKSook6eTJk3r44Yfdbvu5555T586d9eCDD+rQoUOaMGGC7rrrLv38888qVqyYyzJ9+/ZVXFycJOmzzz7T3Llzne7v16+fY370Rx99VGlpaZo4caJ+/vlnrVy5Uj4+Pm4fh6tx7Ngxx75dzG63q3379vr+++/Vt29fVa5cWRs2bNDrr7+u33//XfPmzbviumNiYvTMM8/IGKPt27dr7NixatOmjVNAmpejR4+qTZs26ty5s7p27apPPvlEDz/8sHx9fdW7d29JUmZmpt555x117dpVffr00YkTJ/Tuu+8qISFBq1evVq1atSRJp06dUunSpdWuXTsVKVJEGzdu1KRJk7R371598cUXLtsOCQnR66+/7rh9//33X7He4cOH6+zZs1fs5wm9evXSZ599psGDB6tFixYqU6aMNmzYoJEjR+qBBx5w+f9yJy4uTn379pXdbtfGjRs1btw47du3TytWrLjmug4fPqyzZ8/q4YcfVrNmzfTQQw9p+/btmjRpklatWqVVq1bJz89P0v9fJ6Bu3bpKSUlRenq6xo8fr5UrV7r8f+Xk5KhVq1a644479Morr2jBggVKTk7W+fPnNWrUKLe1nDt3Tv/617+0a9curVy5UpGRkY77xo8fr/bt26t79+7Kzs7Wxx9/rE6dOunLL79UYmLiDXuMAQAAChwDAADwNzN9+nQjySxZssQcOnTI7N6923z88cemRIkSJiAgwOzZs8cYY8zZs2dNTk6O07JpaWnGz8/PjBo1ytE2bdo0I8mMHTvWZVt2u92xnCQzZswYlz5Vq1Y1jRs3dtxetmyZkWRKlSplMjMzHe2ffPKJkWTGjx/vWHfFihVNQkKCYzvGGHP69GlTtmxZ06JFC5dtNWzY0FSrVs1x+9ChQ0aSSU5OdrTt3LnTeHt7mxdffNFp2Q0bNphChQq5tG/bts1IMu+//76jLTk52Vx8KLlixQojycycOdNp2QULFri0R0VFmcTERJfaBwwYYC49PL209ieffNKEhYWZ2NhYp8f0ww8/NF5eXmbFihVOy0+ePNlIMitXrnTZ3sUaN27stD5jjBk2bJiRZA4ePHjFZSWZ1157zdGWlZVlatWqZcLCwkx2drYxxpjz58+brKwsp2WPHj1qwsPDTe/evS+7jf79+5vg4GCX9u7du5uyZcs6tV36mCUlJZmoqCjH7Y0bNxovLy/TunVrI8mkpaUZY4z5888/jSQzbdo0p/Vd+lznbmPAgAEu9SQmJjpt63L/F5db//79+03x4sVNixYtTFZWlqldu7a59dZbzfHjx/NcT66oqCiTlJTk1NatWzcTGBh4XXXl3m7evLk5f/68oz339WbChAnGGGOys7NNWFiYqVatmjlz5oyj35dffmkkmeHDhzvakpKSjCTzyCOPONrsdrtJTEw0vr6+5tChQ071Tp8+3djtdtO9e3cTGBhoVq1a5VL36dOnnW5nZ2ebatWqmWbNmjm1X89jDAAAUBAxdQkAAPjbio+PV2hoqMqUKaMuXbooODhYc+fOdcyv6+fnJy+vC4dDOTk5Onz4sIKDg1WpUiWtW7fOsZ5PP/1UISEheuSRR1y2cemUDlejR48eKly4sOP2vffeq8jISMe0AevXr9e2bdvUrVs3HT58WBkZGcrIyNCpU6fUvHlzfffdd7Lb7U7rPHv2rPz9/S+73c8++0x2u12dO3d2rDMjI0MRERGqWLGiy1Qa2dnZkuQ4W9Wd2bNnq2jRomrRooXTOmNjYxUcHOyyznPnzjn1y8jIuOIZxnv37tWECRP03HPPKTg42GX7lStXVkxMjNM6c6eruXT77uTWdOjQIaWmpmru3LmqUaOGQkJCrrhsoUKF1K9fP8dtX19f9evXTwcPHtTatWslSd7e3vL19ZV04Qz0I0eO6Pz586pTp47T31uu48ePKz09XUuXLtX8+fN11113ufTJzs6+7PPiztChQ3X77berU6dOTu2hoaGSpD179uRrPWfPnnV5Ds+dO+e27+nTp5WRkaGjR486TcmRl4iICE2aNEmLFy9WXFyc1q9fr2nTpqlIkSL5qi0rK0sZGRk6ePCgFi9erG+++UbNmze/7rokafDgwfL29nbcvv/++xUeHq758+dLktasWaODBw+qf//+Tv+LiYmJiomJcfS72MXTwOROC5Odna0lS5a49B0yZIhmzpypTz75xO03OnK/DSJd+KbB8ePHFRcX5/I3dr2PMQAAQEHD1CUAAOBva9K
"text/plain": [
"<Figure size 1800x500 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from locale import normalize\n",
"\n",
"\n",
"def analyze_balance(y_train, y_val, y_test, y_name):\n",
" print(\"Распределение классов в обучающей выборке:\")\n",
" print(y_train.value_counts(normalize=True))\n",
"\n",
" print(\"\\nРаспределение классов в контрольной выборке:\")\n",
" print(y_val.value_counts(normalize=True))\n",
"\n",
" print(\"\\nРаспределение классов в тестовой выборке:\")\n",
" print(y_test.value_counts(normalize=True))\n",
"\n",
" fig, axes = plt.subplots(1, 3, figsize=(18,5), sharey=True)\n",
" fig.suptitle('Распределение в различных выборках')\n",
"\n",
" sns.barplot(x=y_train.value_counts().index, y=y_train.value_counts(normalize=True), ax=axes[0])\n",
" axes[0].set_title('Обучающая выборка')\n",
" axes[0].set_xlabel(y_name)\n",
" axes[0].set_ylabel('Доля')\n",
"\n",
" sns.barplot(x=y_val.value_counts().index, y=y_val.value_counts(normalize=True), ax=axes[1])\n",
" axes[1].set_title('Контрольная выборка')\n",
" axes[1].set_xlabel(y_name)\n",
"\n",
" sns.barplot(x=y_test.value_counts().index, y=y_test.value_counts(normalize=True), ax=axes[2])\n",
" axes[2].set_title('Тестовая выборка')\n",
" axes[2].set_xlabel(y_name)\n",
"\n",
" plt.show()\n",
"\n",
"analyze_balance(y_train, y_val, y_test, 'stroke')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Заметим, что выборки не сбалансированы. Для балансировки будем использовать RandomOverSampler"
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Распределение классов в обучающей выборке:\n",
"stroke\n",
"0 0.5\n",
"1 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в контрольной выборке:\n",
"stroke\n",
"0 0.5\n",
"1 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Распределение классов в тестовой выборке:\n",
"stroke\n",
"0 0.957599\n",
"1 0.042401\n",
"Name: proportion, dtype: float64\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABboAAAHyCAYAAAAtJXgGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABpq0lEQVR4nO3deZyNdf/H8feZMasxxKyWZmwZO40lNNmGwdjqjix3BoVCC90qKoOWSUqEQkUL3Ymiu5Q1SpoIKbIkRvZhbGOdYc7394fHnJ/jnGEsOXPV6/l4zOMx53u+13V9rnPmnO913nOd72UzxhgBAAAAAAAAAGBRXp4uAAAAAAAAAACA60HQDQAAAAAAAACwNIJuAAAAAAAAAIClEXQDAAAAAAAAACyNoBsAAAAAAAAAYGkE3QAAAAAAAAAASyPoBgAAAAAAAABYGkE3AAAAAAAAAMDSCLoBAAAA4BrY7XZlZGRox44dni4FAADgH4+gGwAAAADy6cCBA3r88ccVFRUlX19fhYaGqkqVKsrMzPR0aQAAAP9ohTxdAAAAwI323nvvqVevXo7bfn5+uvXWW9WyZUs999xzCg8P92B1AKzqjz/+UNOmTXXu3Dk9+uijuv3221WoUCEFBASocOHCni4PAADgH42gGwAA/G2NGjVKZcuW1dmzZ/X999/rrbfe0ldffaWNGzcqMDDQ0+UBsJh+/frJ19dXP/74o0qVKuXpcgAAAHARgm4AAPC31bp1a9WpU0eS9OCDD6pEiRIaO3asPv/8c3Xt2tXD1QGwkrVr1+qbb77RokWLCLkBAAAKIOboBgAA/xjNmjWTJKWlpUmSjhw5ov/85z+qXr26goKCFBwcrNatW+uXX35xWfbs2bMaMWKEbrvtNvn7+ysyMlL33HOPtm/fLknauXOnbDZbnj9NmjRxrGv58uWy2WyaNWuWhg0bpoiICBUuXFjt27fX7t27Xba9atUqtWrVSkWLFlVgYKAaN26slStXut3HJk2auN3+iBEjXPrOmDFDsbGxCggIUPHixdWlSxe327/cvl3Mbrdr3Lhxqlq1qvz9/RUeHq5+/frp6NGjTv2io6PVtm1bl+0MHDjQZZ3uah8zZozLYypJWVlZSk5OVoUKFeTn56cyZcroySefVFZWltvH6mKXPm4hISFKTEzUxo0b87VstWrVtHbtWjVs2FABAQEqW7asJk+e7NQvOztbw4cPV2xsrIoWLarChQsrLi5Oy5Ytc+q3detWNWvWTBEREY79eOihh3TkyBGXbffs2fOKz3fPnj0VHR3ttNzu3bsVEBAgm82mnTt3Svr/5/m9995z6jtixAi3z8vAgQNd6mnbtq3TtnLX+eqrr+bx6Lmuf/r06bLZbJo2bZpTv5deekk2m01fffVVnuuSLvx95T4OXl5eioiI0H333addu3ZdV10//vij/P39tX37dlWtWlV+fn6KiIhQv3793D43s2fPdry+QkJC9O9//1t79+516tOzZ08FBQVpx44dSkhIUOHChVWyZEmNGjVKxhiXei9+bk6cOKHY2FiVLVtW+/fvd7S/+uqratiwoUqUKKGAgADFxsZqzpw5Ttu93scYAACgIOKMbgAA8I+RG0qXKFFCkrRjxw7NmzdPnTp1UtmyZZWenq4pU6aocePG2rRpk0qWLClJysnJUdu2bbV06VJ16dJFjz32mE6cOKHFixdr48aNKl++vGMbXbt2VZs2bZy2O3ToULf1vPjii7LZbHrqqad08OBBjRs3TvHx8Vq/fr0CAgIkSd98841at26t2NhYJScny8vLS9OnT1ezZs20YsUK1atXz2W9pUuXVkpKiiTp5MmTevjhh91u+7nnnlPnzp314IMP6tChQ5owYYLuuusu/fzzzypWrJjLMn379lVcXJwk6bPPPtPcuXOd7u/Xr59jfvRHH31UaWlpmjhxon7++WetXLlSPj4+bh+Hq3Hs2DHHvl3Mbrerffv2+v7779W3b19VrlxZGzZs0Ouvv67ff/9d8+bNu+K6Y2Ji9Mwzz8gYo+3bt2vs2LFq06aNU0Cal6NHj6pNmzbq3Lmzunbtqk8++UQPP/ywfH191bt3b0lSZmam3nnnHXXt2lV9+vTRiRMn9O677yohIUGrV69WrVq1JEmnTp1S6dKl1a5dOwUHB2vjxo2aNGmS9u7dqy+++MJl2yEhIXr99dcdt++///4r1jt8+HCdPXv2iv08oVevXvrss880ePBgtWjRQmXKlNGGDRs0cuRIPfDAAy6vL3fi4uLUt29f2e12bdy4UePGjdO+ffu0YsWKa67r8OHDOnv2rB5++GE1a9ZMDz30kLZv365JkyZp1apVWrVqlfz8/CT9/3UC6tatq5SUFKWnp2v8+PFauXKly+srJydHrVq10h133KFXXnlFCxYsUHJyss6fP69Ro0a5reXcuXP617/+pV27dmnlypWKjIx03Dd+/Hi1b99e3bt3V3Z2tj7++GN16tRJX375pRITE2/YYwwAAFDgGAAAgL+Z6dOnG0lmyZIl5tChQ2b37t3m448/NiVKlDABAQFmz549xhhjzp49a3JycpyWTUtLM35+fmbUqFGOtmnTphlJZuzYsS7bstvtjuUkmTFjxrj0qVq1qmncuLHj9rJly4wkU6pUKZOZmelo/+STT4wkM378eMe6K1asaBISEhzbMcaY06dPm7Jly5oWLVq4bKthw4amWrVqjtuHDh0ykkxycrKjbefOncbb29u8+OKLTstu2LDBFCpUyKV927ZtRpJ5//33HW3Jycnm4kPJFStWGElm5syZTssuWLDApT0qKsokJia61D5gwABz6eHppbU/+eSTJiwszMTGxjo9ph9++KHx8vIyK1ascFp+8uTJRpJZuXKly/Yu1rhxY6f1GWPMsGHDjCRz8ODBKy4rybz22muOtqysLFOrVi0TFhZmsrOzjTHGnD9/3mRlZTkte/ToURMeHm569+592W3079/fBAUFubR3797dlC1b1qnt0scsKSnJREVFOW5v3LjReHl5mdatWxtJJi0tzRhjzJ9//mkkmWnTpjmt79LnOncbAwYMcKknMTHRaVuXe11cbv379+83xYsXNy1atDBZWVmmdu3a5tZbbzXHjx/Pcz25oqKiTFJSklNbt27dTGBg4HXVlXu7efPm5vz584723PebCRMmGGOMyc7ONmFhYaZatWrmzJkzjn5ffvmlkWSGDx/uaEtKSjKSzCOPPOJos9vtJjEx0fj6+ppDhw451Tt9+nRjt9tN9+7dTWBgoFm1apVL3adPn3a6nZ2dbapVq2aaNWvm1H49jzEAAEBBxNQlAADgbys+Pl6hoaEqU6aMunTpoqCgIM2dO9cxv66fn5+8vC4cDuXk5Ojw4cMKCgpSpUqVtG7dOsd6Pv30U4WEhOiRRx5x2calUzpcjR49eqhIkSKO2/fee68iIyMd0wasX79e27ZtU7du3XT48GFlZGQoIyNDp06dUvPmzfXdd9/Jbrc7rfPs2bPy9/e/7HY/++wz2e12de7c2bHOjIwMRUREqGLFii5TaWRnZ0uS42xVd2bPnq2iRYuqRYsWTuuMjY1VUFCQyzrPnTvn1C8jI+OKZxjv3btXEyZM0HPPPaegoCCX7VeuXFkxMTFO68ydrubS7buTW9OhQ4eUmpqquXPnqkaNGgoJCbnisoUKFVK/fv0ct319fdWvXz8dPHhQa9eulSR5e3vL19dX0oUz0I8cOaLz58+rTp06Tn9vuY4fP6709HQtXbpU8+fP11133eXSJzs7+7LPiztDhw7V7bffrk6dOjm1h4aGSpL27NmTr/WcPXvW5Tk8d+6c276nT59WRkaGjh496jQlR14iIiI0adIkLV68WHFxcVq/fr2mTZum4ODgfNWWlZWljIwMHTx4UIsXL9Y333yj5s2bX3ddkjR48GB5e3s7bt9///0KDw/X/PnzJUlr1qzRwYMH1b9/f6fXYmJiomJiYhz9LnbxNDC508JkZ2dryZIlLn2HDBmimTNn6pNPPnH7jY7cb4NIF75pcPz4ccXFxbn8jV3vYwwAAFDQMHUJAAD425o
"text/plain": [
"<Figure size 1800x500 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"randoversamp = RandomOverSampler(random_state=42)\n",
"\n",
"# Применение RandomOverSampler для балансировки выборок\n",
"X_train_resampled, y_train_resampled = randoversamp.fit_resample(X_train, y_train)\n",
"X_val_resampled, y_val_resampled = randoversamp.fit_resample(X_val, y_val)\n",
"\n",
"# Проверка сбалансированности после RandomOverSampler\n",
"analyze_balance(y_train_resampled, y_val_resampled, y_test, \"stroke\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Выборки сбалансированы"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Применим унитарное кодирование категориальных признаков (one-hot encoding), переведя их в бинарные вектора."
]
},
{
"cell_type": "code",
"execution_count": 170,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" age hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 53.0 0 0 113.21 28.6 True \n",
"1 62.0 0 0 88.63 24.5 False \n",
"2 17.0 0 0 83.23 28.1 False \n",
"3 77.0 1 0 176.71 33.2 False \n",
"4 7.0 0 0 62.08 16.1 True \n",
"\n",
" gender_Other ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 False True False True \n",
"1 False True False False \n",
"2 False False False True \n",
"3 False True False False \n",
"4 False False False False \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 False False True \n",
"1 False False True \n",
"2 False False False \n",
"3 True False False \n",
"4 False True False \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 False False \n",
"1 False True \n",
"2 False True \n",
"3 False True \n",
"4 False False \n",
"\n",
" smoking_status_smokes \n",
"0 True \n",
"1 False \n",
"2 False \n",
"3 False \n",
"4 False \n"
]
}
],
"source": [
"# Определение категориальных признаков\n",
"categorical_features = [\n",
" \"gender\",\n",
" \"ever_married\",\n",
" \"work_type\",\n",
" \"Residence_type\",\n",
" \"smoking_status\",\n",
"]\n",
"\n",
"# Применение one-hot encoding к обучающей выборке\n",
"X_train_encoded = pd.get_dummies(\n",
" X_train_resampled, columns=categorical_features, drop_first=True\n",
")\n",
"\n",
"# Применение one-hot encoding к контрольной выборке\n",
"X_val_encoded = pd.get_dummies(\n",
" X_val_resampled, columns=categorical_features, drop_first=True\n",
")\n",
"\n",
"# Применение one-hot encoding к тестовой выборке\n",
"X_test_encoded = pd.get_dummies(X_test, columns=categorical_features, drop_first=True)\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Перейдем к числовым признакам, а именно к колонке age, применим дискретизацию (позволяет преобразовать данные из числового представления в категориальное):"
]
},
{
"cell_type": "code",
"execution_count": 171,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 0 0 113.21 28.6 True \n",
"1 0 0 88.63 24.5 False \n",
"2 0 0 83.23 28.1 False \n",
"3 1 0 176.71 33.2 False \n",
"4 0 0 62.08 16.1 True \n",
"\n",
" gender_Other ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 False True False True \n",
"1 False True False False \n",
"2 False False False True \n",
"3 False True False False \n",
"4 False False False False \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 False False True \n",
"1 False False True \n",
"2 False False False \n",
"3 True False False \n",
"4 False True False \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 False False \n",
"1 False True \n",
"2 False True \n",
"3 False True \n",
"4 False False \n",
"\n",
" smoking_status_smokes age_bin \n",
"0 True middle-aged \n",
"1 False old \n",
"2 False young \n",
"3 False old \n",
"4 False young \n"
]
}
],
"source": [
"# Определение числовых признаков для дискретизации\n",
"numerical_features = [\"age\"]\n",
"\n",
"\n",
"# Функция для дискретизации числовых признаков\n",
"def discretize_features(df, features, bins, labels):\n",
" for feature in features:\n",
" df[f\"{feature}_bin\"] = pd.cut(df[feature], bins=bins, labels=labels)\n",
" df.drop(columns=[feature], inplace=True)\n",
" return df\n",
"\n",
"\n",
"# Заданные интервалы и метки\n",
"age_bins = [0, 25, 55, 100]\n",
"age_labels = [\"young\", \"middle-aged\", \"old\"]\n",
"\n",
"# Применение дискретизации к обучающей, контрольной и тестовой выборкам\n",
"X_train_encoded = discretize_features(\n",
" X_train_encoded, numerical_features, bins=age_bins, labels=age_labels\n",
")\n",
"X_val_encoded = discretize_features(\n",
" X_val_encoded, numerical_features, bins=age_bins, labels=age_labels\n",
")\n",
"X_test_encoded = discretize_features(\n",
" X_test_encoded, numerical_features, bins=age_bins, labels=age_labels\n",
")\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Применим ручной синтез признаков. Например, в этом случае создадим признак, в котором вычисляется отклонение уровня глюкозы от среднего для определенной возрастной группы. Вышеуказанный признак может быть полезен для определения пациентов с аномальными данными."
]
},
{
"cell_type": "code",
"execution_count": 172,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 0 0 113.21 28.6 True \n",
"1 0 0 88.63 24.5 False \n",
"2 0 0 83.23 28.1 False \n",
"3 1 0 176.71 33.2 False \n",
"4 0 0 62.08 16.1 True \n",
"\n",
" gender_Other ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 False True False True \n",
"1 False True False False \n",
"2 False False False True \n",
"3 False True False False \n",
"4 False False False False \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 False False True \n",
"1 False False True \n",
"2 False False False \n",
"3 True False False \n",
"4 False True False \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 False False \n",
"1 False True \n",
"2 False True \n",
"3 False True \n",
"4 False False \n",
"\n",
" smoking_status_smokes age_bin glucose_age_deviation \n",
"0 True middle-aged 10.186796 \n",
"1 False old -46.562537 \n",
"2 False young -10.882496 \n",
"3 False old 41.517463 \n",
"4 False young -32.032496 \n"
]
}
],
"source": [
"age_glucose_mean = X_train_encoded.groupby(\"age_bin\", observed=False)[\n",
" \"avg_glucose_level\"\n",
"].transform(\"mean\")\n",
"X_train_encoded[\"glucose_age_deviation\"] = (\n",
" X_train_encoded[\"avg_glucose_level\"] - age_glucose_mean\n",
")\n",
"\n",
"age_glucose_mean = X_val_encoded.groupby(\"age_bin\", observed=False)[\n",
" \"avg_glucose_level\"\n",
"].transform(\"mean\")\n",
"X_val_encoded[\"glucose_age_deviation\"] = (\n",
" X_val_encoded[\"avg_glucose_level\"] - age_glucose_mean\n",
")\n",
"\n",
"age_glucose_mean = X_test_encoded.groupby(\"age_bin\", observed=False)[\n",
" \"avg_glucose_level\"\n",
"].transform(\"mean\")\n",
"X_test_encoded[\"glucose_age_deviation\"] = (\n",
" X_test_encoded[\"avg_glucose_level\"] - age_glucose_mean\n",
")\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Используем масштабирование признаков, для приведения всех числовых признаков к одинаковым или очень похожим диапазонам значений/распределениям. \n",
"### Масштабирование признаков позволяет получить более качественную модель за счет снижения доминирования одних признаков над другими."
]
},
{
"cell_type": "code",
"execution_count": 173,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"0 0 0 -0.120278 -0.174711 True \n",
"1 0 0 -0.561737 -0.757960 False \n",
"2 0 0 -0.658721 -0.245839 False \n",
"3 1 0 1.020189 0.479666 False \n",
"4 0 0 -1.038577 -1.952910 True \n",
"\n",
" gender_Other ever_married_Yes work_type_Never_worked work_type_Private \\\n",
"0 False True False True \n",
"1 False True False False \n",
"2 False False False True \n",
"3 False True False False \n",
"4 False False False False \n",
"\n",
" work_type_Self-employed work_type_children Residence_type_Urban \\\n",
"0 False False True \n",
"1 False False True \n",
"2 False False False \n",
"3 True False False \n",
"4 False True False \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"0 False False \n",
"1 False True \n",
"2 False True \n",
"3 False True \n",
"4 False False \n",
"\n",
" smoking_status_smokes age_bin glucose_age_deviation \n",
"0 True middle-aged 0.192712 \n",
"1 False old -0.880860 \n",
"2 False young -0.205873 \n",
"3 False old 0.785418 \n",
"4 False young -0.605984 \n"
]
}
],
"source": [
"numerical_features = [\"avg_glucose_level\", \"bmi\", \"glucose_age_deviation\"]\n",
"\n",
"scaler = StandardScaler()\n",
"X_train_encoded[numerical_features] = scaler.fit_transform(\n",
" X_train_encoded[numerical_features]\n",
")\n",
"X_val_encoded[numerical_features] = scaler.transform(X_val_encoded[numerical_features])\n",
"X_test_encoded[numerical_features] = scaler.transform(\n",
" X_test_encoded[numerical_features]\n",
")\n",
"\n",
"print(X_train_encoded.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сконструируем признаки, используя фреймворк Featuretools:"
]
},
{
"cell_type": "code",
"execution_count": 178,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" hypertension heart_disease avg_glucose_level bmi gender_Male \\\n",
"index \n",
"0 0 0 -0.120278 -0.174711 True \n",
"1 0 0 -0.561737 -0.757960 False \n",
"2 0 0 -0.658721 -0.245839 False \n",
"3 1 0 1.020189 0.479666 False \n",
"4 0 0 -1.038577 -1.952910 True \n",
"\n",
" gender_Other ever_married_Yes work_type_Never_worked \\\n",
"index \n",
"0 False True False \n",
"1 False True False \n",
"2 False False False \n",
"3 False True False \n",
"4 False False False \n",
"\n",
" work_type_Private work_type_Self-employed work_type_children \\\n",
"index \n",
"0 True False False \n",
"1 False False False \n",
"2 True False False \n",
"3 False True False \n",
"4 False False True \n",
"\n",
" Residence_type_Urban smoking_status_formerly smoked \\\n",
"index \n",
"0 True False \n",
"1 True False \n",
"2 False False \n",
"3 False False \n",
"4 False False \n",
"\n",
" smoking_status_never smoked smoking_status_smokes \\\n",
"index \n",
"0 False True \n",
"1 True False \n",
"2 True False \n",
"3 True False \n",
"4 False False \n",
"\n",
" glucose_age_deviation age_bin_middle-aged age_bin_old \n",
"index \n",
"0 0.192712 True False \n",
"1 -0.880860 False True \n",
"2 -0.205873 False False \n",
"3 0.785418 False True \n",
"4 -0.605984 False False \n"
]
}
],
"source": [
"data = X_train_encoded.copy()\n",
"\n",
"es = ft.EntitySet(id=\"patients\")\n",
"\n",
"es = es.add_dataframe(\n",
" dataframe_name=\"strokes_data\", dataframe=data, index=\"index\", make_index=True\n",
")\n",
"\n",
"feature_matrix, feature_defs = ft.dfs(\n",
" entityset=es, target_dataframe_name=\"strokes_data\", max_depth=1\n",
")\n",
"\n",
"print(feature_matrix.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Оценим качество набора признаков.\n",
"\n",
"1. Предсказательная способность (для задачи классификации)\n",
" - Метрики: Accuracy, Precision, Recall, F1-Score, ROC AUC\n",
" - Методы: Обучение модели на обучающей выборке и оценка на валидационной и тестовой выборках.\n",
"\n",
"2. Вычислительная эффективность\n",
" - Методы: Измерение времени, затраченного на генерацию признаков и обучение модели.\n",
"\n",
"3. Надежность\n",
" - Методы: Кросс-валидация и анализ чувствительности модели к изменениям в данных.\n",
"\n",
"4. Корреляция\n",
" - Методы: Анализ корреляционной матрицы признаков и исключение мультиколлинеарных признаков.\n",
"\n",
"5. Логическая согласованность\n",
" - Методы: Проверка логической связи признаков с целевой переменной и интерпретация результатов модели."
]
},
{
"cell_type": "code",
"execution_count": 175,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Время обучения модели: 0.43 секунд\n"
]
}
],
"source": [
"X_train_encoded = pd.get_dummies(X_train_encoded, drop_first=True)\n",
"X_val_encoded = pd.get_dummies(X_val_encoded, drop_first=True)\n",
"X_test_encoded = pd.get_dummies(X_test_encoded, drop_first=True)\n",
"\n",
"all_columns = X_train_encoded.columns\n",
"X_train_encoded = X_train_encoded.reindex(columns=all_columns, fill_value=0)\n",
"X_val_encoded = X_val_encoded.reindex(columns=all_columns, fill_value=0)\n",
"X_test_encoded = X_test_encoded.reindex(columns=all_columns, fill_value=0)\n",
"\n",
"# Выбор модели\n",
"model = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"\n",
"# Начинаем отсчет времени\n",
"start_time = time.time()\n",
"model.fit(X_train_encoded, y_train_resampled)\n",
"\n",
"# Время обучения модели\n",
"train_time = time.time() - start_time\n",
"\n",
"print(f\"Время обучения модели: {train_time:.2f} секунд\")"
]
},
{
"cell_type": "code",
"execution_count": 176,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Feature Importance:\n",
" feature importance\n",
"2 avg_glucose_level 0.195252\n",
"3 bmi 0.184431\n",
"15 glucose_age_deviation 0.181013\n",
"17 age_bin_old 0.168510\n",
"16 age_bin_middle-aged 0.031286\n",
"0 hypertension 0.026751\n",
"6 ever_married_Yes 0.026492\n",
"11 Residence_type_Urban 0.025740\n",
"4 gender_Male 0.024989\n",
"9 work_type_Self-employed 0.022764\n",
"1 heart_disease 0.021314\n",
"8 work_type_Private 0.020773\n",
"13 smoking_status_never smoked 0.019126\n",
"12 smoking_status_formerly smoked 0.017622\n",
"10 work_type_children 0.017389\n",
"14 smoking_status_smokes 0.016418\n",
"7 work_type_Never_worked 0.000127\n",
"5 gender_Other 0.000003\n"
]
}
],
"source": [
"# Получение важности признаков\n",
"importances = model.feature_importances_\n",
"feature_names = X_train_encoded.columns\n",
"\n",
"# Сортировка признаков по важности\n",
"feature_importance = pd.DataFrame({\"feature\": feature_names, \"importance\": importances})\n",
"feature_importance = feature_importance.sort_values(by=\"importance\", ascending=False)\n",
"\n",
"print(\"Feature Importance:\")\n",
"print(feature_importance)"
]
},
{
"cell_type": "code",
"execution_count": 177,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.949119373776908\n",
"Precision: 0.15789473684210525\n",
"Recall: 0.046153846153846156\n",
"F1 Score: 0.07142857142857142\n",
"ROC AUC: 0.5176273317962691\n",
"Cross-validated Accuracy: 0.991578947368421\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABCEAAAIjCAYAAAA9agHPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADfcklEQVR4nOzdeVSV1fv38fcBFJBRFBUUBRMR53lMJYcc0hxK0zBERbMip0yzcsA50xwyszRBy6EyNXMqtTBFcxZNSRFFrDD6OoBoDsB5/vDx/nUCFRAh7fNa66zF2Xvfe1/3fcjVudiDyWw2mxERERERERERecCsCjoAEREREREREflvUBJCRERERERERPKFkhAiIiIiIiIiki+UhBARERERERGRfKEkhIiIiIiIiIjkCyUhRERERERERCRfKAkhIiIiIiIiIvlCSQgRERERERERyRdKQoiIiIiIiIhIvlASQkRERERERETyhZIQIiIi8p8WERGByWTK8vXGG288kDF37tzJuHHjuHTp0gPp/37cfh779u0r6FBybd68eURERBR0GCIikgWbgg5ARERE5N9g/Pjx+Pj4WJRVrVr1gYy1c+dOwsLCCA4OxtXV9YGM8V82b948ihcvTnBwcEGHIiIi/6AkhIiIiAjQrl076tatW9Bh3JcrV67g4OBQ0GEUmKtXr1KkSJGCDkNERO5CyzFEREREsmHjxo00bdoUBwcHnJyceOqppzh69KhFm8OHDxMcHEz58uWxs7OjVKlS9O3bl/Pnzxttxo0bx+uvvw6Aj4+PsfQjPj6e+Ph4TCZTlksJTCYT48aNs+jHZDJx7Ngxnn/+eYoWLcrjjz9u1H/22WfUqVMHe3t73Nzc6NGjB2fPns3VvQcHB+Po6EhCQgIdOnTA0dGR0qVL88EHHwBw5MgRWrRogYODA+XKlWPZsmUW199e4vHjjz/y4osvUqxYMZydnQkKCuLixYuZxps3bx5VqlTB1tYWT09PXnnllUxLVwICAqhatSr79++nWbNmFClShDfffBNvb2+OHj3Ktm3bjGcbEBAAwIULFxg+fDjVqlXD0dERZ2dn2rVrR3R0tEXfkZGRmEwmvvjiCyZNmkSZMmWws7OjZcuWnDx5MlO8u3fvpn379hQtWhQHBweqV6/O7NmzLdr88ssvPPvss7i5uWFnZ0fdunVZu3ZtTj8KEZGHnmZCiIiIiADJycn873//sygrXrw4AJ9++im9e/emTZs2vPPOO1y9epUPP/yQxx9/nIMHD+Lt7Q3A5s2bOXXqFH369KFUqVIcPXqUjz/+mKNHj/LTTz9hMpno2rUrJ06cYPny5cycOdMYw93dnT///DPHcXfr1g1fX18mT56M2WwGYNKkSYwePZru3bsTEhLCn3/+yfvvv0+zZs04ePBgrpaApKen065dO5o1a8a0adNYunQpoaGhODg48NZbbxEYGEjXrl2ZP38+QUFBNGrUKNPyltDQUFxdXRk3bhzHjx/nww8/5MyZM8aXfriVXAkLC6NVq1a89NJLRru9e/cSFRVFoUKFjP7Onz9Pu3bt6NGjB7169aJkyZIEBATw6quv4ujoyFtvvQVAyZIlATh16hRr1qyhW7du+Pj48Mcff/DRRx/RvHlzjh07hqenp0W8U6dOxcrKiuHDh5OcnMy0adMIDAxk9+7dRpvNmzfToUMHPDw8GDx4MKVKlSImJoZ169YxePBgAI4ePUqTJk0oXbo0b7zxBg4ODnzxxRd07tyZr776ii5duuT48xAReWiZRURERP7DwsPDzUCWL7PZbL58+bLZ1dXV3L9/f4vrzp07Z3ZxcbEov3r1aqb+ly9fbgbMP/74o1H27rvvmgHz6dOnLdqePn3aDJjDw8Mz9QOYx44da7wfO3asGTD37NnTol18fLzZ2traPGnSJIvyI0eOmG1sbDKV3+l57N271yjr3bu3GTBPnjzZKLt48aLZ3t7ebDKZzCtWrDDKf/nll0yx3u6zTp065hs3bhjl06ZNMwPmr7/+2mw2m81JSUnmwoULm5988klzenq60W7u3LlmwLxo0SKjrHnz5mbAPH/+/Ez3UKVKFXPz5s0zlV+7ds2iX7P51jO3tbU1jx8/3ij74YcfzIDZ39/ffP36daN89uzZZsB85MgRs9lsNqelpZl9fHzM5cqVM1+8eNGi34yMDOPnli1bmqtVq2a+du2aRX3jxo3Nvr6+meIUEXmUaTmGiIiICPDBBx+wefNmixfc+kv3pUuX6NmzJ//73/+Ml7W1NQ0aNOCHH34w+rC3tzd+vnbtGv/73/9o2LAhAAcOHHggcQ8cONDi/apVq8jIyKB79+4W8ZYqVQpfX1+LeHMqJCTE+NnV1RU/Pz8cHBzo3r27Ue7n54erqyunTp3KdP2AAQMsZjK89NJL2NjYsGHDBgC2bNnCjRs3GDJkCFZW//e/qf3798fZ2Zn169db9Gdra0ufPn2yHb+tra3Rb3p6OufPn8fR0RE/P78sP58+ffpQuHBh433Tpk0BjHs7ePAgp0+fZsiQIZlml9ye2XHhwgW+//57unfvzuXLl43P4/z587Rp04bY2Fh+++23bN+DiMjDTssxRERERID69etnuTFlbGwsAC1atMjyOmdnZ+PnCxcuEBYWxooVK0hKSrJol5ycnIfR/p9/LnmIjY3FbDbj6+ubZfu/JwFyws7ODnd3d4syFxcXypQpY3zh/nt5Vns9/DMmR0dHPDw8iI+PB+DMmTPArUTG3xUuXJjy5csb9beVLl3aIklwLxkZGcyePZt58+Zx+vRp0tPTjbpixYplal+2bFmL90WLFgUw7i0uLg64+ykqJ0+exGw2M3r0aEaPHp1lm6SkJEqXLp3t+xAReZgpCSEiIiJyFxkZGcCtfSFKlSqVqd7G5v/+d6p79+7s3LmT119/nZo1a+Lo6EhGRgZt27Y1+rmbf36Zv+3vX5b/6e+zL27HazKZ2LhxI9bW1pnaOzo63jOOrGTV193Kzf9/f4oH6Z/3fi+TJ09m9OjR9O3blwkTJuDm5oaVlRVDhgzJ8vPJi3u73e/w4cNp06ZNlm0qVKiQ7f5ERB52SkKIiIiI3MVjjz0GQIkSJWjVqtUd2128eJGtW7cSFhbGmDFjjPLbMyn+7k7Jhtt/af/nSRD/nAFwr3jNZjM+Pj5UrFgx29flh9jYWJ544gnjfWpqKomJibRv3x6AcuXKAXD8+HHKly9vtLtx4wanT5++6/P/uzs935UrV/LEE0/wySefWJRfunTJ2CA0J27/bvz88893jO32fRQqVCjb8YuIPMq0J4SIiIjIXbRp0wZnZ2cmT57MzZs3M9XfPtHi9l/N//lX8lmzZmW6xsHBAcicbHB2dqZ48eL8+OOPFuXz5s3Ldrxdu3bF2tqasLCwTLGYzWaL40Lz28cff2zxDD/88EPS0tJo164dAK1ataJw4cLMmTPHIvZPPvmE5ORknnrqqWyN4+DgkOnZwq3P6J/P5Msvv8z1ngy1a9fGx8eHWbNmZRrv9jglSpQgICCAjz76iMTExEx95OZEFBGRh5lmQoiIiIjchbOzMx9++CEvvPACtWvXpkePHri7u5OQkMD69etp0qQJc+fOxdnZ2Ti+8ubNm5QuXZrvvvuO06dPZ+qzTp06ALz11lv06NGDQoUK0bFjRxwcHAgJCWHq1KmEhIRQt25dfvzxR06cOJHteB977DEmTpzIqFGjiI+Pp3Pnzjg5OXH69GlWr17NgAEDGD58eJ49n5y4ceMGLVu2pHv37hw/fpx58+bx+OOP8/TTTwO3jikdNWoUYWFhtG3blqefftpoV69ePXr16pWtcerUqcOHH37IxIkTqVChAiVKlKBFixZ06NCB8ePH06dPHxo3bsyRI0dYunSpxayLnLCysuLDDz+kY8eO1KxZkz59+uDh4cEvv/zC0aNH+fbbb4Fbm54+/vjjVKtWjf79+1O+fHn++OMPdu3axa+//kp0dHSuxhcReRgpCSEiIiJyD88//zyenp5MnTqVd999l+vXr1O6dGmaNm1qcTrDsmXLePXVV/nggw8wm808+eSTbNy4EU9
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Accuracy: 1.0\n",
"Train Precision: 1.0\n",
"Train Recall: 1.0\n",
"Train F1 Score: 1.0\n",
"Train ROC AUC: 1.0\n"
]
}
],
"source": [
"# Предсказание и оценка\n",
"y_pred = model.predict(X_test_encoded)\n",
"\n",
"accuracy = accuracy_score(y_test, y_pred)\n",
"precision = precision_score(y_test, y_pred)\n",
"recall = recall_score(y_test, y_pred)\n",
"f1 = f1_score(y_test, y_pred)\n",
"roc_auc = roc_auc_score(y_test, y_pred)\n",
"\n",
"print(f\"Accuracy: {accuracy}\")\n",
"print(f\"Precision: {precision}\")\n",
"print(f\"Recall: {recall}\")\n",
"print(f\"F1 Score: {f1}\")\n",
"print(f\"ROC AUC: {roc_auc}\")\n",
"\n",
"# Кросс-валидация\n",
"scores = cross_val_score(\n",
" model, X_train_encoded, y_train_resampled, cv=5, scoring=\"accuracy\"\n",
")\n",
"accuracy_cv = scores.mean()\n",
"print(f\"Cross-validated Accuracy: {accuracy_cv}\")\n",
"\n",
"# Анализ важности признаков\n",
"feature_importances = model.feature_importances_\n",
"feature_names = X_train_encoded.columns\n",
"\n",
"importance_df = pd.DataFrame(\n",
" {\"Feature\": feature_names, \"Importance\": feature_importances}\n",
")\n",
"importance_df = importance_df.sort_values(by=\"Importance\", ascending=False)\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"sns.barplot(x=\"Importance\", y=\"Feature\", data=importance_df)\n",
"plt.title(\"Feature Importance\")\n",
"plt.show()\n",
"\n",
"# Проверка на переобучение\n",
"y_train_pred = model.predict(X_train_encoded)\n",
"\n",
"accuracy_train = accuracy_score(y_train_resampled, y_train_pred)\n",
"precision_train = precision_score(y_train_resampled, y_train_pred)\n",
"recall_train = recall_score(y_train_resampled, y_train_pred)\n",
"f1_train = f1_score(y_train_resampled, y_train_pred)\n",
"roc_auc_train = roc_auc_score(y_train_resampled, y_train_pred)\n",
"\n",
"print(f\"Train Accuracy: {accuracy_train}\")\n",
"print(f\"Train Precision: {precision_train}\")\n",
"print(f\"Train Recall: {recall_train}\")\n",
"print(f\"Train F1 Score: {f1_train}\")\n",
"print(f\"Train ROC AUC: {roc_auc_train}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}