AIM-PIbd-31-Rodionov-I-A/lab_4/lab4.ipynb

2024-11-15 23:54:15 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stroke data\n",
"\n",
"Let us load the dataset and display its contents:"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>9046</th>\n",
" <td>Male</td>\n",
" <td>67.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>228.69</td>\n",
" <td>36.6</td>\n",
" <td>formerly smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>51676</th>\n",
" <td>Female</td>\n",
" <td>61.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>202.21</td>\n",
" <td>NaN</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31112</th>\n",
" <td>Male</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>105.92</td>\n",
" <td>32.5</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>60182</th>\n",
" <td>Female</td>\n",
" <td>49.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>171.23</td>\n",
" <td>34.4</td>\n",
" <td>smokes</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1665</th>\n",
" <td>Female</td>\n",
" <td>79.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>174.12</td>\n",
" <td>24.0</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18234</th>\n",
" <td>Female</td>\n",
" <td>80.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>83.75</td>\n",
" <td>NaN</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>44873</th>\n",
" <td>Female</td>\n",
" <td>81.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>125.20</td>\n",
" <td>40.0</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19723</th>\n",
" <td>Female</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>82.99</td>\n",
" <td>30.6</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37544</th>\n",
" <td>Male</td>\n",
" <td>51.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>166.29</td>\n",
" <td>25.6</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>44679</th>\n",
" <td>Female</td>\n",
" <td>44.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>85.28</td>\n",
" <td>26.2</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5110 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"9046 Male 67.0 0 1 Yes Private \n",
"51676 Female 61.0 0 0 Yes Self-employed \n",
"31112 Male 80.0 0 1 Yes Private \n",
"60182 Female 49.0 0 0 Yes Private \n",
"1665 Female 79.0 1 0 Yes Self-employed \n",
"... ... ... ... ... ... ... \n",
"18234 Female 80.0 1 0 Yes Private \n",
"44873 Female 81.0 0 0 Yes Self-employed \n",
"19723 Female 35.0 0 0 Yes Self-employed \n",
"37544 Male 51.0 0 0 Yes Private \n",
"44679 Female 44.0 0 0 Yes Govt_job \n",
"\n",
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
"id \n",
"9046 Urban 228.69 36.6 formerly smoked 1 \n",
"51676 Rural 202.21 NaN never smoked 1 \n",
"31112 Rural 105.92 32.5 never smoked 1 \n",
"60182 Urban 171.23 34.4 smokes 1 \n",
"1665 Rural 174.12 24.0 never smoked 1 \n",
"... ... ... ... ... ... \n",
"18234 Urban 83.75 NaN never smoked 0 \n",
"44873 Urban 125.20 40.0 never smoked 0 \n",
"19723 Rural 82.99 30.6 never smoked 0 \n",
"37544 Rural 166.29 25.6 formerly smoked 0 \n",
"44679 Urban 85.28 26.2 Unknown 0 \n",
"\n",
"[5110 rows x 11 columns]"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"random_state = 9\n",
"\n",
"df = pd.read_csv(\"../../static/csv/healthcare-dataset-stroke-data.csv\", index_col=\"id\")\n",
"\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Business goals\n",
"\n",
"### Classification\n",
"\n",
"Goal: develop a classification model that predicts whether a person is likely to have a stroke, based on socio-demographic factors, health status, and lifestyle.\n",
"\n",
"Applications:\n",
"\n",
"1. Medical institutions: the model can be used for early identification of patients at high risk of stroke. This allows preventive measures to be taken and reduces the likelihood of serious consequences.\n",
"2. Clinical decision support systems: the model can be embedded in electronic health records to automatically alert physicians to patients in the elevated-risk group.\n",
"3. Educational programs: the model can help raise public awareness of stroke risk factors and of ways to reduce them, which may also improve disease prevention.\n",
"\n",
"### Regression\n",
"\n",
"Goal: develop a regression model that predicts a person's blood glucose level from socio-demographic factors, health status, and lifestyle. The model will make it possible to detect a tendency toward higher or lower glucose levels and, later, to assess the risks associated with the patient's condition.\n",
"\n",
"Applications:\n",
"\n",
"1. Medical institutions: help identify patients with potentially high glucose levels early, so they can be monitored and given preventive measures that reduce the risk of diabetes and other complications.\n",
"2. Clinical decision support systems: integrating the model into medical records would give physicians an estimate of the glucose level. This simplifies patient monitoring and management, especially when real-time laboratory data are unavailable.\n",
"3. Educational programs and public health: the model can help raise public awareness of the factors that affect glucose levels and support lifestyle recommendations for keeping glucose within the normal range.\n",
"\n",
"## Achievable model quality\n",
"\n",
"A classification model for predicting stroke on this dataset can reach good quality, but with some limitations.\n",
"\n",
"- Informative features: the dataset contains features that are important stroke risk factors (for example, age, hypertension, and heart disease). These data most likely give the model enough information to recognize the elevated-risk group.\n",
"\n",
"- Data limitations: although the key medical factors are present, the dataset does not include genetic data, a detailed medical history, or detailed data on diet and physical activity, all of which also affect stroke risk. This may cap the model's maximum quality.\n",
"\n",
"For the regression task of predicting blood glucose level, the model can also reach reasonable quality, but its predictions will have limited accuracy.\n",
"\n",
"- Feature informativeness: the data contain features related to glucose level (for example, age, smoking, and hypertension), which can be used to produce a coarse, population-level forecast.\n",
"\n",
"- Missing factors: glucose level depends strongly on diet, physical activity, and hormonal changes, none of which are present in the data. Because of this, the model will have limited accuracy when estimating this parameter."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification\n",
"\n",
"Split the dataset into training and test sets (80/20), stratified by the target. The target feature is `stroke`."
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from typing import Tuple\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input,\n",
"    stratify_colname=\"y\",\n",
"    frac_train=0.6,\n",
"    frac_val=0.15,\n",
"    frac_test=0.25,\n",
"    random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"\n",
"    Splits a Pandas dataframe into three subsets (train, val, and test)\n",
"    following fractional ratios provided by the user, where each subset is\n",
"    stratified by the values in a specific column (that is, each subset has\n",
"    the same relative frequency of the values in the column). It performs this\n",
"    splitting by running train_test_split() twice.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df_input : Pandas dataframe\n",
"        Input dataframe to be split.\n",
"    stratify_colname : str\n",
"        The name of the column that will be used for stratification. Usually\n",
"        this column would be for the label.\n",
"    frac_train : float\n",
"    frac_val : float\n",
"    frac_test : float\n",
"        The ratios with which the dataframe will be split into train, val, and\n",
"        test data. The values should be expressed as float fractions and should\n",
"        sum to 1.0.\n",
"    random_state : int, None, or RandomStateInstance\n",
"        Value to be passed to train_test_split().\n",
"\n",
"    Returns\n",
"    -------\n",
"    df_train, df_val, df_test, y_train, y_val, y_test :\n",
"        Dataframes containing the three splits, followed by the\n",
"        stratification column for each split. If frac_val <= 0, df_val and\n",
"        y_val are empty dataframes.\n",
"    \"\"\"\n",
"\n",
"    # Compare with a tolerance: an exact != check can reject valid\n",
"    # fractions because of floating-point rounding in the sum.\n",
"    if abs(frac_train + frac_val + frac_test - 1.0) > 1e-9:\n",
"        raise ValueError(\n",
"            \"fractions %f, %f, %f do not add up to 1.0\"\n",
"            % (frac_train, frac_val, frac_test)\n",
"        )\n",
"\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"\n",
"    X = df_input  # Contains all columns.\n",
"    y = df_input[\n",
"        [stratify_colname]\n",
"    ]  # Dataframe of just the column on which to stratify.\n",
"\n",
"    # Split original dataframe into train and temp dataframes.\n",
"    df_train, df_temp, y_train, y_temp = train_test_split(\n",
"        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
"    )\n",
"\n",
"    if frac_val <= 0:\n",
"        assert len(df_input) == len(df_train) + len(df_temp)\n",
"        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
"\n",
"    # Split the temp dataframe into val and test dataframes.\n",
"    relative_frac_test = frac_test / (frac_val + frac_test)\n",
"    df_val, df_test, y_val, y_test = train_test_split(\n",
"        df_temp,\n",
"        y_temp,\n",
"        stratify=y_temp,\n",
"        test_size=relative_frac_test,\n",
"        random_state=random_state,\n",
"    )\n",
"\n",
"    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"    return df_train, df_val, df_test, y_train, y_val, y_test"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>22159</th>\n",
" <td>Female</td>\n",
" <td>54.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>97.06</td>\n",
" <td>28.5</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8920</th>\n",
" <td>Female</td>\n",
" <td>51.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>76.35</td>\n",
" <td>33.5</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65507</th>\n",
" <td>Male</td>\n",
" <td>33.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>55.72</td>\n",
" <td>38.2</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43196</th>\n",
" <td>Female</td>\n",
" <td>52.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>59.54</td>\n",
" <td>42.2</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>59745</th>\n",
" <td>Female</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>76.74</td>\n",
" <td>53.9</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>66546</th>\n",
" <td>Female</td>\n",
" <td>20.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>80.08</td>\n",
" <td>25.1</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68798</th>\n",
" <td>Female</td>\n",
" <td>58.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>59.86</td>\n",
" <td>28.0</td>\n",
" <td>formerly smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61409</th>\n",
" <td>Male</td>\n",
" <td>32.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>58.24</td>\n",
" <td>NaN</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69259</th>\n",
" <td>Female</td>\n",
" <td>77.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>100.85</td>\n",
" <td>29.5</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17231</th>\n",
" <td>Female</td>\n",
" <td>24.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>90.42</td>\n",
" <td>24.3</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"22159 Female 54.0 1 0 No Private \n",
"8920 Female 51.0 0 0 Yes Self-employed \n",
"65507 Male 33.0 0 0 Yes Private \n",
"43196 Female 52.0 0 0 Yes Self-employed \n",
"59745 Female 27.0 0 0 Yes Private \n",
"... ... ... ... ... ... ... \n",
"66546 Female 20.0 0 0 No Private \n",
"68798 Female 58.0 0 0 Yes Private \n",
"61409 Male 32.0 1 0 No Govt_job \n",
"69259 Female 77.0 0 0 Yes Private \n",
"17231 Female 24.0 0 0 No Private \n",
"\n",
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
"id \n",
"22159 Urban 97.06 28.5 formerly smoked 0 \n",
"8920 Rural 76.35 33.5 formerly smoked 0 \n",
"65507 Rural 55.72 38.2 never smoked 0 \n",
"43196 Urban 59.54 42.2 Unknown 0 \n",
"59745 Urban 76.74 53.9 Unknown 0 \n",
"... ... ... ... ... ... \n",
"66546 Urban 80.08 25.1 never smoked 0 \n",
"68798 Rural 59.86 28.0 formerly smoked 1 \n",
"61409 Urban 58.24 NaN formerly smoked 0 \n",
"69259 Rural 100.85 29.5 smokes 0 \n",
"17231 Urban 90.42 24.3 never smoked 0 \n",
"\n",
"[4088 rows x 11 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>22159</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8920</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65507</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43196</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>59745</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>66546</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68798</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61409</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69259</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17231</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" stroke\n",
"id \n",
"22159 0\n",
"8920 0\n",
"65507 0\n",
"43196 0\n",
"59745 0\n",
"... ...\n",
"66546 0\n",
"68798 1\n",
"61409 0\n",
"69259 0\n",
"17231 0\n",
"\n",
"[4088 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>18072</th>\n",
" <td>Female</td>\n",
" <td>39.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>107.47</td>\n",
" <td>21.3</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>67063</th>\n",
" <td>Male</td>\n",
" <td>62.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>130.56</td>\n",
" <td>36.1</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40387</th>\n",
" <td>Female</td>\n",
" <td>17.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>77.46</td>\n",
" <td>24.0</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18032</th>\n",
" <td>Male</td>\n",
" <td>62.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>90.61</td>\n",
" <td>25.8</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5478</th>\n",
" <td>Female</td>\n",
" <td>60.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>203.04</td>\n",
" <td>NaN</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>57710</th>\n",
" <td>Female</td>\n",
" <td>50.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>112.25</td>\n",
" <td>21.6</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63043</th>\n",
" <td>Female</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>61.80</td>\n",
" <td>26.8</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63986</th>\n",
" <td>Male</td>\n",
" <td>60.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>153.48</td>\n",
" <td>37.3</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28461</th>\n",
" <td>Male</td>\n",
" <td>15.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Never_worked</td>\n",
" <td>Rural</td>\n",
" <td>79.59</td>\n",
" <td>28.4</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54975</th>\n",
" <td>Male</td>\n",
" <td>7.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>64.06</td>\n",
" <td>18.9</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1022 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"18072 Female 39.0 0 0 Yes Govt_job \n",
"67063 Male 62.0 0 0 Yes Self-employed \n",
"40387 Female 17.0 0 0 No Private \n",
"18032 Male 62.0 0 1 Yes Private \n",
"5478 Female 60.0 0 0 Yes Self-employed \n",
"... ... ... ... ... ... ... \n",
"57710 Female 50.0 0 0 Yes Private \n",
"63043 Female 27.0 0 0 No Private \n",
"63986 Male 60.0 0 0 Yes Private \n",
"28461 Male 15.0 0 0 No Never_worked \n",
"54975 Male 7.0 0 0 No Self-employed \n",
"\n",
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
"id \n",
"18072 Urban 107.47 21.3 Unknown 0 \n",
"67063 Urban 130.56 36.1 Unknown 0 \n",
"40387 Rural 77.46 24.0 Unknown 0 \n",
"18032 Rural 90.61 25.8 smokes 0 \n",
"5478 Urban 203.04 NaN smokes 0 \n",
"... ... ... ... ... ... \n",
"57710 Rural 112.25 21.6 Unknown 0 \n",
"63043 Urban 61.80 26.8 formerly smoked 0 \n",
"63986 Rural 153.48 37.3 never smoked 0 \n",
"28461 Rural 79.59 28.4 Unknown 0 \n",
"54975 Rural 64.06 18.9 Unknown 0 \n",
"\n",
"[1022 rows x 11 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>18072</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>67063</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40387</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18032</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5478</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>57710</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63043</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63986</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28461</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54975</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1022 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" stroke\n",
"id \n",
"18072 0\n",
"67063 0\n",
"40387 0\n",
"18032 0\n",
"5478 0\n",
"... ...\n",
"57710 0\n",
"63043 0\n",
"63986 0\n",
"28461 0\n",
"54975 0\n",
"\n",
"[1022 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"stroke\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
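{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split above passes `stratify` so that the rare positive class keeps the same share in the training and test sets. The next cell is a minimal, self-contained sketch on synthetic labels (hypothetical data, not the stroke dataset). It takes 1000 rows with 50 positives (5%), splits them 80/20 with and without stratification, and prints the positive rate of each part."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hypothetical imbalanced labels: 950 negatives and 50 positives (5%)\n",
"y_demo = np.array([0] * 950 + [1] * 50)\n",
"X_demo = np.arange(len(y_demo)).reshape(-1, 1)\n",
"\n",
"for stratify in (None, y_demo):\n",
"    _, _, y_tr, y_te = train_test_split(\n",
"        X_demo, y_demo, test_size=0.2, stratify=stratify, random_state=9\n",
"    )\n",
"    label = \"stratified\" if stratify is not None else \"plain\"\n",
"    # With stratify=y_demo both parts keep exactly 5% positives;\n",
"    # without it the shares depend on the random draw.\n",
"    print(f\"{label:>10}: train rate = {y_tr.mean():.3f}, test rate = {y_te.mean():.3f}\")"
]
},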
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us establish a baseline for the classification task. For this we use a random predictor: for every example, the prediction is a class chosen uniformly at random."
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline Accuracy: 0.5205479452054794\n",
"Baseline Precision: 0.05823293172690763\n",
"Baseline Recall: 0.58\n",
"Baseline F1 Score: 0.10583941605839416\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score\n",
"\n",
"# Get the unique classes of the target feature from the training set\n",
"unique_classes = np.unique(y_train)\n",
"\n",
"# Generate random predictions: for each test example, draw a class uniformly at random.\n",
"# No seed is set, so the metrics below differ between runs.\n",
"random_predictions = np.random.choice(unique_classes, size=len(y_test))\n",
"\n",
"# Compute the baseline metrics\n",
"baseline_accuracy = accuracy_score(y_test, random_predictions)\n",
"baseline_precision = precision_score(y_test, random_predictions)\n",
"baseline_recall = recall_score(y_test, random_predictions)\n",
"baseline_f1 = f1_score(y_test, random_predictions)\n",
"\n",
"print('Baseline Accuracy:', baseline_accuracy)\n",
"print('Baseline Precision:', baseline_precision)\n",
"print('Baseline Recall:', baseline_recall)\n",
"print('Baseline F1 Score:', baseline_f1)"
]
},
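{
"cell_type": "markdown",
"metadata": {},
"source": [
"The random baseline above is not seeded, so its metrics change from run to run. One reproducible alternative is scikit-learn's `DummyClassifier(strategy=\"uniform\")` with a fixed `random_state`.\n",
"\n",
"The next cell is a sketch on synthetic labels with roughly this dataset's class balance (about 5% positives, an assumption for illustration). With uniform random guessing, recall is expected to be near 0.5 and precision near the positive-class rate. That is the same pattern as in the printed baseline: high recall relative to a very low precision."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.dummy import DummyClassifier\n",
"from sklearn.metrics import precision_score, recall_score\n",
"\n",
"rng = np.random.default_rng(9)\n",
"# Hypothetical labels with about 5% positives (not the stroke dataset)\n",
"y_demo = (rng.random(10_000) < 0.05).astype(int)\n",
"X_demo = np.zeros((len(y_demo), 1))  # features are ignored by DummyClassifier\n",
"\n",
"# strategy=\"uniform\" predicts each class with equal probability;\n",
"# a fixed random_state makes the predictions reproducible\n",
"dummy = DummyClassifier(strategy=\"uniform\", random_state=9)\n",
"dummy.fit(X_demo, y_demo)\n",
"pred = dummy.predict(X_demo)\n",
"\n",
"print(\"positive rate:\", y_demo.mean())\n",
"print(\"precision:\", precision_score(y_demo, pred))  # expected to be close to the positive rate\n",
"print(\"recall:\", recall_score(y_demo, pred))  # expected to be close to 0.5"
]
},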
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following metrics were used:\n",
"\n",
"- Accuracy: the share of correct predictions among all examples. It is intuitive, but misleading when classes are imbalanced, because it does not show how well the model predicts the rarer class. Here, a model that always predicts the absence of a stroke would already reach about 95% accuracy.\n",
"- Precision: the share of true positives among all examples predicted as positive. Precision shows how selective the model is when predicting the positive class. It matters when false alarms are undesirable (for example, wrongly predicting a stroke).\n",
"- Recall: the share of all positive-class examples that the model finds. Recall reflects the model's ability to detect every example of the positive class. In this task a high recall is important because it minimizes missed stroke cases.\n",
"- F1 Score: the harmonic mean of precision and recall, which balances the two. This metric matters when both precision and recall have to be taken into account. F1 is especially useful when classes are imbalanced and a balance is needed between detecting all cases (recall) and minimizing false alarms (precision).\n",
"\n",
"Together these metrics cover different aspects of model performance, from the ability to recognize the rare class to overall accuracy, so the model can be assessed from several angles."
]
},
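{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make these definitions concrete, the next cell recomputes the baseline metrics by hand. The confusion-matrix counts are not printed by the notebook. They are inferred from the printed baseline values:\n",
"\n",
"- A recall of 0.58 on the 1022-row test set is only consistent with 50 actual strokes, giving TP = 29 and FN = 21.\n",
"- A precision of about 0.0582 then implies 498 predicted positives, so FP = 469.\n",
"- The remaining non-strokes give TN = 503."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Confusion-matrix counts inferred from the baseline output (1022 test rows)\n",
"tp, fn = 29, 21  # 50 actual strokes\n",
"fp, tn = 469, 503  # 972 actual non-strokes\n",
"total = tp + tn + fp + fn\n",
"\n",
"accuracy = (tp + tn) / total\n",
"precision = tp / (tp + fp)\n",
"recall = tp / (tp + fn)\n",
"f1 = 2 * precision * recall / (precision + recall)  # harmonic mean\n",
"\n",
"print(\"Accuracy:\", accuracy)\n",
"print(\"Precision:\", precision)\n",
"print(\"Recall:\", recall)\n",
"print(\"F1 Score:\", f1)\n",
"\n",
"# Cross-check against the values printed by the baseline cell\n",
"assert math.isclose(accuracy, 0.5205479452054794)\n",
"assert math.isclose(precision, 0.05823293172690763)\n",
"assert math.isclose(recall, 0.58)\n",
"assert math.isclose(f1, 0.10583941605839416)\n",
"\n",
"# Always predicting \"no stroke\" gives high accuracy but zero recall\n",
"always_negative_accuracy = (tn + fp) / total\n",
"print(\"Always-negative accuracy:\", round(always_negative_accuracy, 3))"
]
},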
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building the classification pipeline"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"# work_type is excluded from the features; stroke is the target\n",
"columns_to_drop = [\"work_type\", \"stroke\"]\n",
"# Binary 0/1 indicators are passed through unchanged\n",
"columns_not_to_modify = [\"hypertension\", \"heart_disease\"]\n",
"\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop\n",
" and column not in columns_not_to_modify\n",
" and df[column].dtype != \"object\"\n",
"]\n",
"\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop\n",
" and column not in columns_not_to_modify\n",
" and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"# Numeric features: fill gaps with the median, then standardize\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
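"# Categorical features: fill gaps with a constant, then one-hot encode, dropping the first level\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",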
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
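{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what each stage of the pipeline does, the next cell runs the same structure on a tiny hand-made frame. It is a self-contained sketch with hypothetical values and a reduced set of columns (`age`, `bmi`, `gender`, `hypertension`, `work_type`, `stroke`). It assumes scikit-learn 1.2 or newer, which is needed for `sparse_output` and pandas output.\n",
"\n",
"- Numeric columns are median-imputed and standardized.\n",
"- `gender` is one-hot encoded with its first level dropped.\n",
"- `hypertension` passes through unchanged.\n",
"- `work_type` and `stroke` are removed by the second `ColumnTransformer`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn import set_config\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Hypothetical toy data with the same kinds of columns as the stroke dataset\n",
"toy = pd.DataFrame({\n",
"    \"age\": [10.0, 50.0, np.nan, 70.0],\n",
"    \"bmi\": [20.0, np.nan, 30.0, 25.0],\n",
"    \"gender\": [\"Male\", \"Female\", \"Female\", \"Male\"],\n",
"    \"hypertension\": [0, 1, 0, 1],\n",
"    \"work_type\": [\"Private\", \"Govt_job\", \"Private\", \"Private\"],\n",
"    \"stroke\": [0, 1, 0, 0],\n",
"})\n",
"\n",
"num_pipe = Pipeline([\n",
"    (\"imputer\", SimpleImputer(strategy=\"median\")),  # fill NaN with the column median\n",
"    (\"scaler\", StandardScaler()),  # zero mean, unit variance\n",
"])\n",
"cat_pipe = Pipeline([\n",
"    (\"imputer\", SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")),\n",
"    (\"encoder\", OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")),\n",
"])\n",
"\n",
"preprocess = ColumnTransformer(\n",
"    [(\"num\", num_pipe, [\"age\", \"bmi\"]), (\"cat\", cat_pipe, [\"gender\"])],\n",
"    remainder=\"passthrough\",  # hypertension, work_type and stroke are kept as-is\n",
"    verbose_feature_names_out=False,\n",
")\n",
"drop = ColumnTransformer(\n",
"    [(\"drop\", \"drop\", [\"work_type\", \"stroke\"])],\n",
"    remainder=\"passthrough\",  # everything except the dropped columns\n",
"    verbose_feature_names_out=False,\n",
")\n",
"\n",
"toy_pipeline = Pipeline([(\"preprocess\", preprocess), (\"drop\", drop)])\n",
"result = toy_pipeline.fit_transform(toy)\n",
"print(result)\n",
"print(list(result.columns))"
]
},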
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let us check that the pipeline works:"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>gender_Male</th>\n",
" <th>gender_Other</th>\n",
" <th>ever_married_Yes</th>\n",
" <th>Residence_type_Urban</th>\n",
" <th>smoking_status_formerly smoked</th>\n",
" <th>smoking_status_never smoked</th>\n",
" <th>smoking_status_smokes</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>22159</th>\n",
" <td>0.472344</td>\n",
" <td>-0.194427</td>\n",
" <td>-0.059214</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8920</th>\n",
" <td>0.339807</td>\n",
" <td>-0.653763</td>\n",
" <td>0.587887</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65507</th>\n",
" <td>-0.455418</td>\n",
" <td>-1.111325</td>\n",
" <td>1.196162</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43196</th>\n",
" <td>0.383986</td>\n",
" <td>-1.026600</td>\n",
" <td>1.713843</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>59745</th>\n",
" <td>-0.720492</td>\n",
" <td>-0.645113</td>\n",
" <td>3.228060</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>66546</th>\n",
" <td>-1.029746</td>\n",
" <td>-0.571034</td>\n",
" <td>-0.499243</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68798</th>\n",
" <td>0.649060</td>\n",
" <td>-1.019502</td>\n",
" <td>-0.123924</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61409</th>\n",
" <td>-0.499597</td>\n",
" <td>-1.055433</td>\n",
" <td>-0.098040</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69259</th>\n",
" <td>1.488464</td>\n",
" <td>-0.110367</td>\n",
" <td>0.070206</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17231</th>\n",
" <td>-0.853030</td>\n",
" <td>-0.341699</td>\n",
" <td>-0.602779</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" age avg_glucose_level bmi gender_Male gender_Other \\\n",
"id \n",
"22159 0.472344 -0.194427 -0.059214 0.0 0.0 \n",
"8920 0.339807 -0.653763 0.587887 0.0 0.0 \n",
"65507 -0.455418 -1.111325 1.196162 1.0 0.0 \n",
"43196 0.383986 -1.026600 1.713843 0.0 0.0 \n",
"59745 -0.720492 -0.645113 3.228060 0.0 0.0 \n",
"... ... ... ... ... ... \n",
"66546 -1.029746 -0.571034 -0.499243 0.0 0.0 \n",
"68798 0.649060 -1.019502 -0.123924 0.0 0.0 \n",
"61409 -0.499597 -1.055433 -0.098040 1.0 0.0 \n",
"69259 1.488464 -0.110367 0.070206 0.0 0.0 \n",
"17231 -0.853030 -0.341699 -0.602779 0.0 0.0 \n",
"\n",
" ever_married_Yes Residence_type_Urban smoking_status_formerly smoked \\\n",
"id \n",
"22159 0.0 1.0 1.0 \n",
"8920 1.0 0.0 1.0 \n",
"65507 1.0 0.0 0.0 \n",
"43196 1.0 1.0 0.0 \n",
"59745 1.0 1.0 0.0 \n",
"... ... ... ... \n",
"66546 0.0 1.0 0.0 \n",
"68798 1.0 0.0 1.0 \n",
"61409 0.0 1.0 1.0 \n",
"69259 1.0 0.0 0.0 \n",
"17231 0.0 1.0 0.0 \n",
"\n",
" smoking_status_never smoked smoking_status_smokes hypertension \\\n",
"id \n",
"22159 0.0 0.0 1 \n",
"8920 0.0 0.0 0 \n",
"65507 1.0 0.0 0 \n",
"43196 0.0 0.0 0 \n",
"59745 0.0 0.0 0 \n",
"... ... ... ... \n",
"66546 1.0 0.0 0 \n",
"68798 0.0 0.0 0 \n",
"61409 0.0 0.0 1 \n",
"69259 0.0 1.0 0 \n",
"17231 1.0 0.0 0 \n",
"\n",
" heart_disease \n",
"id \n",
"22159 0 \n",
"8920 0 \n",
"65507 0 \n",
"43196 0 \n",
"59745 0 \n",
"... ... \n",
"66546 0 \n",
"68798 0 \n",
"61409 0 \n",
"69259 0 \n",
"17231 0 \n",
"\n",
"[4088 rows x 12 columns]"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
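{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output above contains `gender_Male` and `gender_Other` but no `gender_Female` column. Likewise, `smoking_status_Unknown`, `ever_married_No` and `Residence_type_Rural` are missing. The reason is that `OneHotEncoder(drop=\"first\")` drops the first category of each feature in sorted order, which avoids a redundant, perfectly collinear dummy column. Below is a minimal sketch of the same preprocessing pattern on a tiny hypothetical DataFrame. The values are invented for illustration and are not taken from the dataset, and scikit-learn 1.2 or newer is assumed for `sparse_output`:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"# Hypothetical toy data: a numeric column with a gap, a categorical column,\n",
"# and a binary column that should pass through unchanged\n",
"toy = pd.DataFrame(\n",
"    {\n",
"        \"bmi\": [22.6, np.nan, 30.1, 27.4],\n",
"        \"gender\": [\"Female\", \"Male\", \"Other\", \"Female\"],\n",
"        \"hypertension\": [0, 1, 0, 0],\n",
"    }\n",
")\n",
"\n",
"num = Pipeline([(\"imputer\", SimpleImputer(strategy=\"median\")), (\"scaler\", StandardScaler())])\n",
"cat = Pipeline(\n",
"    [\n",
"        (\"imputer\", SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")),\n",
"        (\"encoder\", OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")),\n",
"    ]\n",
")\n",
"ct = ColumnTransformer(\n",
"    [(\"num\", num, [\"bmi\"]), (\"cat\", cat, [\"gender\"])],\n",
"    remainder=\"passthrough\",  # hypertension is kept as is\n",
"    verbose_feature_names_out=False,\n",
")\n",
"\n",
"out = ct.fit_transform(toy)\n",
"print(list(ct.get_feature_names_out()))\n",
"# ['bmi', 'gender_Male', 'gender_Other', 'hypertension']  (\"Female\" was dropped)\n",
"```"
]
},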
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's tune the hyperparameters of each selected model with a grid search and collect the tuned models.\n",
"\n",
"knn: k-nearest neighbors\n",
"\n",
"random_forest: random forest (an ensemble of decision trees)\n",
"\n",
"mlp: multilayer perceptron (a neural network)"
]
},
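{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before running the search in the next cell, it is worth estimating its cost. Below is a minimal sketch that counts the candidate combinations with `sklearn.model_selection.ParameterGrid`. The grids are copied from that cell and `random_state = 9` is taken from the printed best parameters. The figure of 5 folds per candidate is an assumption based on the `GridSearchCV` default, since the notebook does not pass `cv`:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.model_selection import ParameterGrid\n",
"\n",
"random_state = 9  # value shown in the printed best parameters\n",
"\n",
"# Same grids as in the grid search cell\n",
"param_grids = {\n",
"    \"knn\": {\"n_neighbors\": list(range(1, 31)), \"weights\": [\"uniform\", \"distance\"]},\n",
"    \"random_forest\": {\n",
"        \"n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"        \"max_features\": [\"sqrt\", \"log2\", 2],\n",
"        \"max_depth\": list(range(2, 11)),\n",
"        \"criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"        \"random_state\": [random_state],\n",
"        \"class_weight\": [\"balanced\", \"balanced_subsample\"],\n",
"    },\n",
"    \"mlp\": {\n",
"        \"solver\": [\"adam\"],\n",
"        \"max_iter\": list(range(1000, 2001, 100)),\n",
"        \"alpha\": 10.0 ** -np.arange(1, 10),\n",
"        \"hidden_layer_sizes\": np.arange(10, 15),\n",
"        \"early_stopping\": [True, False],\n",
"        \"random_state\": [random_state],\n",
"    },\n",
"}\n",
"\n",
"n_folds = 5  # assumed: GridSearchCV default when cv is not given\n",
"grid_sizes = {name: len(ParameterGrid(grid)) for name, grid in param_grids.items()}\n",
"for name, size in grid_sizes.items():\n",
"    # plus one final refit of the best candidate (refit=True by default)\n",
"    print(f\"{name}: {size} combinations, {size * n_folds} CV fits\")\n",
"# knn: 60 combinations, 300 CV fits\n",
"# random_forest: 1620 combinations, 8100 CV fits\n",
"# mlp: 990 combinations, 4950 CV fits\n",
"```\n",
"\n",
"Under these assumptions, the random forest grid accounts for the largest number of fits."
]
},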
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters for knn: {'n_neighbors': 1, 'weights': 'uniform'}\n",
"Best parameters for random_forest: {'class_weight': 'balanced_subsample', 'criterion': 'entropy', 'max_depth': 7, 'max_features': 'sqrt', 'n_estimators': 50, 'random_state': 9}\n",
"Best parameters for mlp: {'alpha': np.float64(0.1), 'early_stopping': True, 'hidden_layer_sizes': np.int64(14), 'max_iter': 1000, 'random_state': 9, 'solver': 'adam'}\n"
]
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn import neighbors, ensemble, neural_network\n",
"\n",
"# Candidate hyperparameter values for each model\n",
"param_grids = {\n",
"    \"knn\": {\n",
"        \"n_neighbors\": list(range(1, 31)),\n",
"        \"weights\": [\"uniform\", \"distance\"],\n",
"    },\n",
"    \"random_forest\": {\n",
"        \"n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"        \"max_features\": [\"sqrt\", \"log2\", 2],\n",
"        \"max_depth\": list(range(2, 11)),\n",
"        \"criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"        \"random_state\": [random_state],\n",
"        \"class_weight\": [\"balanced\", \"balanced_subsample\"],\n",
"    },\n",
"    \"mlp\": {\n",
"        \"solver\": [\"adam\"],\n",
"        \"max_iter\": list(range(1000, 2001, 100)),\n",
"        \"alpha\": 10.0 ** -np.arange(1, 10),\n",
"        \"hidden_layer_sizes\": np.arange(10, 15),\n",
"        \"early_stopping\": [True, False],\n",
"        \"random_state\": [random_state],\n",
"    },\n",
"}\n",
"\n",
"# Model instances to tune\n",
"models = {\n",
"    \"knn\": neighbors.KNeighborsClassifier(),\n",
"    \"random_forest\": ensemble.RandomForestClassifier(),\n",
"    \"mlp\": neural_network.MLPClassifier(),\n",
"}\n",
"\n",
"# Models configured with their best hyperparameters\n",
"class_models = {}\n",
"\n",
"# Run the grid search for every model\n",
"for model_name, model in models.items():\n",
"    # 5-fold cross-validation (the GridSearchCV default), scored by F1 for the positive class\n",
"    gs_optimizer = GridSearchCV(\n",
"        estimator=model, param_grid=param_grids[model_name], scoring=\"f1\", n_jobs=-1\n",
"    )\n",
"\n",
"    # Fit on the already preprocessed training data\n",
"    gs_optimizer.fit(preprocessed_df, y_train.values.ravel())\n",
"\n",
"    best_params = gs_optimizer.best_params_\n",
"    print(f\"Best parameters for {model_name}: {best_params}\")\n",
"\n",
"    # Configure the model instance with the best parameters\n",
"    class_models[model_name] = {\"model\": model.set_params(**best_params)}"
]
},
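{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the cell above, `GridSearchCV` is fitted on `preprocessed_df`. This means the imputer and scaler were fitted on the whole training set before cross-validation, so every validation fold has already influenced the preprocessing statistics (a mild form of data leakage). Below is a minimal sketch of the usual alternative: put the preprocessing and the model into one `Pipeline` and prefix the grid keys with the step name. Synthetic imbalanced data from `make_classification` stands in for the stroke dataset, and only the KNN grid is shown:\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Synthetic stand-in for the stroke data: about 5% positive class\n",
"X, y = make_classification(n_samples=1000, n_features=6, weights=[0.95], random_state=9)\n",
"\n",
"# The scaler is refitted on the training part of every CV split only\n",
"pipe = Pipeline([(\"scaler\", StandardScaler()), (\"model\", KNeighborsClassifier())])\n",
"\n",
"# Parameters of a pipeline step are addressed as <step name>__<parameter>\n",
"param_grid = {\n",
"    \"model__n_neighbors\": list(range(1, 31)),\n",
"    \"model__weights\": [\"uniform\", \"distance\"],\n",
"}\n",
"\n",
"gs = GridSearchCV(pipe, param_grid=param_grid, scoring=\"f1\", cv=5)\n",
"gs.fit(X, y)\n",
"print(gs.best_params_)\n",
"\n",
"# With refit=True (the default), best_estimator_ is the whole pipeline refitted on all of X, y\n",
"best_pipeline = gs.best_estimator_\n",
"```\n",
"\n",
"In the notebook, the same idea would mean passing `pipeline_end` as the first pipeline step and `X_train` (not `preprocessed_df`) to `fit`, with every key in `param_grids` prefixed by `model__`."
]
},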
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's train the models with the tuned parameters and evaluate their quality."
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn.metrics import (\n",
"    accuracy_score,\n",
"    confusion_matrix,\n",
"    f1_score,\n",
"    precision_score,\n",
"    recall_score,\n",
")\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"for model_name, model_info in class_models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    model = model_info[\"model\"]\n",
"\n",
"    # Preprocessing + tuned model, fitted on the raw training data\n",
"    model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
"    model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
"    # Training predictions; test predictions from P(stroke) with a 0.5 threshold\n",
"    y_train_predict = model_pipeline.predict(X_train)\n",
"    y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
"    y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
"    model_info[\"pipeline\"] = model_pipeline\n",
"    model_info[\"probs\"] = y_test_probs\n",
"    model_info[\"preds\"] = y_test_predict\n",
"\n",
"    # Metrics on the training and test sets\n",
"    model_info[\"Precision_train\"] = precision_score(y_train, y_train_predict)\n",
"    model_info[\"Precision_test\"] = precision_score(y_test, y_test_predict)\n",
"    model_info[\"Recall_train\"] = recall_score(y_train, y_train_predict)\n",
"    model_info[\"Recall_test\"] = recall_score(y_test, y_test_predict)\n",
"    model_info[\"Accuracy_train\"] = accuracy_score(y_train, y_train_predict)\n",
"    model_info[\"Accuracy_test\"] = accuracy_score(y_test, y_test_predict)\n",
"    model_info[\"F1_train\"] = f1_score(y_train, y_train_predict)\n",
"    model_info[\"F1_test\"] = f1_score(y_test, y_test_predict)\n",
"    model_info[\"Confusion_matrix\"] = confusion_matrix(y_test, y_test_predict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Confusion matrices:"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABBIAAANrCAYAAAD70rtBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC8qklEQVR4nOzdfXzN9f/H8efZ9Ww7G2IzZsjlihQ1QynGSKX4VWrlIvFNVBSVvrloyqILIhddyMU36qsLvkIKhTCSUkKuQ9gUZi7a1Tmf3x9y6pwNZ2ezq8/jfrt9brfO+/N+f87r7Hu+28vr836/PxbDMAwBAAAAAAC4waukAwAAAAAAAGUHhQQAAAAAAOA2CgkAAAAAAMBtFBIAAAAAAIDbKCQAAAAAAAC3UUgAAAAAAABuo5AAAAAAAADcRiEBAAAAAAC4jUICAAAAAABwG4UEAJKkUaNGyWKx6I8//ijpUAAAQDGyWCwaNWpUSYdxQf/5z3/UsGFD+fr6KiwsrKTDASAKCQAAAABKqV9++UW9evXSlVdeqXfeeUdvv/12SYeUx+HDhzVq1Cht3ry5pEMBio1PSQcAAAAAAPlZuXKl7Ha73njjDdWtW7ekw8nX4cOH9cILL6hWrVpq2rRpSYcDFAtmJAAAAAClyJkzZ0o6hFLj6NGjklSkSxrOnj1bZNcCzIpCAoAL2r9/v+rWraurr75aaWlpuvnmm3X11Vdr27ZtuuWWW1ShQgVVr15d48aNcxq3cuVKWSwWzZs3Ty+99JJq1KihgIAAtWvXTrt37y6hTwMAQOlzfo+ibdu26f7771fFihXVunVr/fTTT+rVq5fq1KmjgIAARURE6KGHHtKxY8fyHb9792716tVLYWFhCg0NVe/evfP8gzkrK0uDBw9WlSpVFBISojvuuEO//fZbvnH98MMP6tSpk6xWq4KDg9WuXTutX7/eqc/MmTNlsVi0Zs0aPf7446pSpYrCwsL0r3/9S9nZ2UpPT1ePHj1UsWJFVaxYUU8//bQMw3D7Z1OrVi2NHDlSklSlSpU8ezlMmTJFV111lfz9/RUZGakBAwYoPT3d6Rrnc5dNmzbppptuUoUKFfTcc885fh4jR45U3bp15e/vr6ioKD399NPKyspyusayZcvUunVrhYWFKTg4WA0aNHBcY+XKlbr++uslSb1795bFYpHFYtHMmTPd/pxAWcTSBgD52rNnj9q2batKlSpp2bJluuKKKyRJJ06cUMeOHdW1a1fdc889+vjjj/XMM8+ocePG6tSpk9M1Xn75ZXl5eWnIkCE6efKkxo0bp8TERG3YsKEkPhIAAKXW3XffrXr16mnMmDEyDEPLli3T3r171bt3b0VERGjr1q16++23tXXrVq1fv14Wi8Vp/D333KPatWsrOTlZ33//vd59911VrVpVY8eOdfR5+OGH9f777+v+++9Xy5Yt9dVXX6lz5855Ytm6datuvPFGWa1WPf300/L19dVbb72lm2++WatWrVJsbKxT/8cee0wRERF64YUXtH79er399tsKCwvTunXrVLNmTY0ZM0ZLlizRK6+8oquvvlo9evRw62cyYcIEzZ49W/Pnz9fUqVMVHBysJk2aSDpXQHnhhRcUHx+v/v37a8eOHZo6dao2btyotWvXytfX13GdY8eOqVOnTurevbseeOABhYeHy26364477tCaNWvUr18/NWrUSFu2bNH48eO1c+dOLViwwPGzuO2229SkSRMlJSXJ399fu3fv1tq1ayVJjRo1UlJSkkaMGKF+/frpxhtvlCS1bNnSrc8IlFkGABiGMXLkSEOS8fvvvxvbt283IiMjjeuvv944fvy4o0+bNm0MScbs2bMdbVlZWUZERITRrVs3R9vXX39tSDIaNWpkZGVlOdrfeOMNQ5KxZcuW4vlQAACUcuf//t53331O7WfPns3T94MPPjAkGatXr84z/qGHHnLqe9dddxmVK1d2vN68ebMhyXj00U
ed+t1///2GJGPkyJGOtjvvvNPw8/Mz9uzZ42g7fPiwERISYtx0002OthkzZhiSjISEBMNutzva4+LiDIvFYjzyyCOOttzcXKNGjRpGmzZtLvETcfbP/OS8o0ePGn5+fkaHDh0Mm83maH/zzTcNScZ7773naDufu0ybNs3puv/5z38MLy8v45tvvnFqnzZtmiHJWLt2rWEYhjF+/Pg87+9q48aNhiRjxowZBfpsQFnG0gYATn7++We1adNGtWrV0vLly1WxYkWn88HBwXrggQccr/38/HTDDTdo7969ea7Vu3dv+fn5OV6fr9Ln1xcAADN75JFHnF4HBgY6/jszM1N//PGHWrRoIUn6/vvvLzn+xhtv1LFjx5SRkSFJWrJkiSTp8ccfd+o3aNAgp9c2m01ffvml7rzzTtWpU8fRXq1aNd1///1as2aN45rn9enTx2mGRGxsrAzDUJ8+fRxt3t7eat68eZHkAMuXL1d2drYGDRokL6+//znTt29fWa1WLV682Km/v7+/evfu7dT20UcfqVGjRmrYsKH++OMPx9G2bVtJ0tdffy3p770Z/ve//8lutxc6dqC8oJAAwMntt9+ukJAQffHFF7JarXnO16hRI890yooVK+rEiRN5+tasWTNPP0n59gUAwMxq167t9Pr48eN64oknFB4ersDAQFWpUsXR5+TJk3nGX+pv7v79++Xl5aUrr7zSqV+DBg2cXv/+++86e/Zsnnbp3DR+u92ugwcPXvS9Q0NDJUlRUVF52osiB9i/f3++sfv5+alOnTqO8+dVr17d6caGJO3atUtbt25VlSpVnI769etL+nuTx3vvvVetWrXSww8/rPDwcHXv3l3z5s2jqADTY48EAE66deumWbNmac6cOfrXv/6V57y3t3e+44x8Nk8qSF8AAMzsnzMQpHN7Hqxbt05Dhw5V06ZNFRwcLLvdro4dO+b7j9iS/Jt7offOr70kcgDXn60k2e12NW7cWK+//nq+Y84XQQIDA7V69Wp9/fXXWrx4sZYuXar//ve/atu2rb788ssLfnagvKOQAMDJK6+8Ih8fHz366KMKCQnR/fffX9IhAQBgKidOnNCKFSv0wgsvaMSIEY72Xbt2eXzN6Oho2e127dmzx+lO/o4dO5z6ValSRRUqVMjTLkm//PKLvLy88sw0KG7R0dGSzsX+z+UX2dnZ2rdvn+Lj4y95jSuvvFI//vij2rVrl2empSsvLy+1a9dO7dq10+uvv64xY8bo3//+t77++mvFx8dfcjxQHrG0AYATi8Wit99+W//3f/+nnj17auHChSUdEgAApnL+Lrfr3fsJEyZ4fM3zT1aaOHHiRa/p7e2tDh066H//+59+/fVXR3taWprmzp2r1q1b57v0sTjFx8fLz89PEydOdPoZTZ8+XSdPnsz3SRSu7rnnHh06dEjvvPNOnnN//vmnzpw5I+ncEhNXTZs2lSTHYyKDgoIkKc+jJ4HyjBkJAPLw8vLS+++/rzvvvFP33HOPlixZ4th8CAAAXF5Wq1U33XSTxo0bp5ycHFWvXl1ffvml9u3b5/E1mzZtqvvuu09TpkzRyZMn1bJlS61YsUK7d+/O0/fFF1/UsmXL1Lp1az366KPy8fHRW2+9paysLI0bN64wH61IVKlSRcOGDdMLL7ygjh076o477tCOHTs0ZcoUXX/99U6bQl/Igw8+qHnz5umRRx7R119/rVatWslms+mXX37RvHnz9MUXX6h58+ZKSkrS6tWr1blzZ0VHR+vo0aOaMmWKatSoodatW0s6N7shLCxM06ZNU0hIiIKCghQbG5tn3wugPKGQACBfvr6++vjjj9WpUyd16dJFy5cvL+mQAAAwjblz5+qxxx7T5MmTZRiGOnTooM8//1yRkZEeX/O9995TlSpVNGfOHC1YsEBt27bV4sWL8yxVuOqqq/TNN99o2LBhSk5Olt1uV2xsrN5//33FxsYW9qMViVGjRqlKlSp68803NXjwYFWqVE
n9+vXTmDFj5Ovre8nxXl5eWrBggcaPH6/Zs2dr/vz5qlChgurUqaMnnnjCseniHXfcoV9//VXvvfee/vjjD11xxRV
"text/plain": [
"<Figure size 1200x1000 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# 2x2 grid of panels, one per model\n",
"fig, ax = plt.subplots(2, 2, figsize=(12, 10))\n",
"\n",
"for index, (key, model_info) in enumerate(class_models.items()):\n",
"    c_matrix = model_info[\"Confusion_matrix\"]\n",
"\n",
"    # Rows are true classes, columns are predicted classes (0 = no stroke, 1 = stroke)\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Not stroke\", \"Stroke\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"\n",
"    disp.ax_.set_title(key)\n",
"\n",
"# Remove the panels that are not used (3 models, 4 panels)\n",
"for i in range(len(class_models), len(ax.flat)):\n",
"    fig.delaxes(ax.flat[i])\n",
"\n",
"plt.subplots_adjust(top=0.9, bottom=0.1, hspace=0.4, wspace=0.3)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Precision, Recall, Accuracy, F1:"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_a7559_row0_col0 {\n",
" background-color: #1f988b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a7559_row0_col1, #T_a7559_row1_col0, #T_a7559_row1_col2, #T_a7559_row2_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_a7559_row0_col2, #T_a7559_row0_col3, #T_a7559_row1_col1, #T_a7559_row2_col0 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a7559_row0_col4 {\n",
" background-color: #b7318a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a7559_row0_col5, #T_a7559_row1_col4, #T_a7559_row1_col6, #T_a7559_row2_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a7559_row0_col6, #T_a7559_row0_col7, #T_a7559_row2_col4, #T_a7559_row2_col5 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a7559_row1_col3, #T_a7559_row2_col1 {\n",
" background-color: #1f968b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a7559_row1_col5 {\n",
" background-color: #be3885;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a7559_row1_col7 {\n",
" background-color: #9c179e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a7559_row2_col2 {\n",
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
"#T_a7559_row2_col6 {\n",
" background-color: #8808a6;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_a7559\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_a7559_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_a7559_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_a7559_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_a7559_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_a7559_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_a7559_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_a7559_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_a7559_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_a7559_level0_row0\" class=\"row_heading level0 row0\" >mlp</th>\n",
" <td id=\"T_a7559_row0_col0\" class=\"data row0 col0\" >0.400000</td>\n",
" <td id=\"T_a7559_row0_col1\" class=\"data row0 col1\" >0.200000</td>\n",
" <td id=\"T_a7559_row0_col2\" class=\"data row0 col2\" >0.020101</td>\n",
" <td id=\"T_a7559_row0_col3\" class=\"data row0 col3\" >0.020000</td>\n",
" <td id=\"T_a7559_row0_col4\" class=\"data row0 col4\" >0.950832</td>\n",
" <td id=\"T_a7559_row0_col5\" class=\"data row0 col5\" >0.948141</td>\n",
" <td id=\"T_a7559_row0_col6\" class=\"data row0 col6\" >0.038278</td>\n",
" <td id=\"T_a7559_row0_col7\" class=\"data row0 col7\" >0.036364</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a7559_level0_row1\" class=\"row_heading level0 row1\" >knn</th>\n",
" <td id=\"T_a7559_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_a7559_row1_col1\" class=\"data row1 col1\" >0.117647</td>\n",
" <td id=\"T_a7559_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_a7559_row1_col3\" class=\"data row1 col3\" >0.120000</td>\n",
" <td id=\"T_a7559_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_a7559_row1_col5\" class=\"data row1 col5\" >0.912916</td>\n",
" <td id=\"T_a7559_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_a7559_row1_col7\" class=\"data row1 col7\" >0.118812</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a7559_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_a7559_row2_col0\" class=\"data row2 col0\" >0.228869</td>\n",
" <td id=\"T_a7559_row2_col1\" class=\"data row2 col1\" >0.135135</td>\n",
" <td id=\"T_a7559_row2_col2\" class=\"data row2 col2\" >0.884422</td>\n",
" <td id=\"T_a7559_row2_col3\" class=\"data row2 col3\" >0.500000</td>\n",
" <td id=\"T_a7559_row2_col4\" class=\"data row2 col4\" >0.849315</td>\n",
" <td id=\"T_a7559_row2_col5\" class=\"data row2 col5\" >0.818982</td>\n",
" <td id=\"T_a7559_row2_col6\" class=\"data row2 col6\" >0.363636</td>\n",
" <td id=\"T_a7559_row2_col7\" class=\"data row2 col7\" >0.212766</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x23c21f18560>"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Brief analysis of the metrics:\n",
"\n",
"1. MLP (multilayer perceptron)\n",
"\n",
"    Precision: 0.40 on the training set, 0.20 on the test set\n",
"\n",
"    Recall: 0.02 on the training set, 0.02 on the test set\n",
"\n",
"    Accuracy: 0.95 on the training set, 0.95 on the test set\n",
"\n",
"    F1 score: 0.038 on the training set, 0.036 on the test set\n",
"\n",
"    Conclusion: the high accuracy on both sets only means that the model labels the majority class (\"no stroke\") correctly. With only about 5% strokes in the data, a model that always predicts \"no stroke\" already reaches about 0.95 accuracy. The very low recall and F1 score point to strong bias: the model almost never detects positive examples.\n",
"\n",
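"2. KNN (k-nearest neighbors)\n",
"\n",
"    Precision: 1.0 on the training set, 0.118 on the test set\n",
"\n",
"    Recall: 1.0 on the training set, 0.12 on the test set\n",
"\n",
"    Accuracy: 1.0 on the training set, 0.91 on the test set\n",
"\n",
"    F1 score: 1.0 on the training set, 0.119 on the test set\n",
"\n",
"    Conclusion: the model clearly overfits. The grid search selected n_neighbors = 1, so every training point is its own nearest neighbor and the perfect training scores are expected. On the test set, precision, recall and F1 drop to about 0.12, and accuracy (0.91) falls below the roughly 0.95 obtained by always predicting \"no stroke\".\n",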
"\n",
"3. Random Forest\n",
"\n",
"    Precision: 0.229 on the training set, 0.135 on the test set\n",
"\n",
"    Recall: 0.88 on the training set, 0.50 on the test set\n",
"\n",
"    Accuracy: 0.85 on the training set, 0.82 on the test set\n",
"\n",
"    F1 score: 0.364 on the training set, 0.213 on the test set\n",
"\n",
"    Conclusion: compared with the other models, this one has the most balanced metrics, but they are still hard to call good. It finds half of the strokes in the test set (recall 0.50), yet precision stays low: about 86% of the records it flags as strokes are false positives. Reducing these false alarms is the main thing to improve.\n",
"\n",
"Comparison with the baseline:\n",
"\n",
"- Baseline Accuracy: 0.52\n",
"- Baseline Precision: 0.058\n",
"- Baseline Recall: 0.58\n",
"- Baseline F1 Score: 0.106\n",
"\n",
"Accuracy: all models outperform the baseline by a wide margin, especially MLP and KNN. Random Forest also beats it, though less dramatically. With only about 5% positive examples, however, accuracy mostly reflects how well the majority class is predicted.\n",
"\n",
"Precision: all models beat the baseline (0.058), but precision remains low everywhere. KNN (0.118) and Random Forest (0.135) have the lowest values.\n",
"\n",
"Recall: the baseline (0.58) has a higher recall than every model. MLP (0.02) and KNN (0.12) struggle the most to find positive examples. Random Forest comes closest (0.50) but still does not reach the baseline.\n",
"\n",
"F1 Score: Random Forest has the best F1 score (0.213, about twice the baseline's 0.106), reflecting a better balance between precision and recall, but it is still well below a desirable level. KNN (0.119) is only slightly above the baseline, and MLP (0.036) is below it.\n",
"\n",
"Conclusions about bias and variance:\n",
"\n",
"MLP: high bias. The model barely recognizes positive examples on either the training or the test set, despite its high overall accuracy.\n",
"\n",
"KNN: high variance. The model is heavily overfitted to the training set and generalizes poorly to the test set.\n",
"\n",
"Random Forest: the most balanced model, with moderate bias and variance. It offers the best trade-off between precision and recall, although precision remains low.\n",
"\n",
"Conclusion:\n",
"\n",
"Random Forest can be considered the best model here because it shows the best balance across the metrics. It has the highest test F1 score even though it ranks last by test accuracy, the column the table is sorted by. Still, this model is far from ideal."
]
},
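{
"cell_type": "markdown",
"metadata": {},
"source": [
"The claim that accuracy is dominated by the majority class can be checked with a little arithmetic on the metrics table. The test recalls 0.02, 0.12 and 0.50 are exactly 1/50, 6/50 and 25/50. This suggests, as an inference rather than something the notebook prints, that the test set contains 50 strokes out of 1022 records. The sketch below rebuilds each model's test confusion counts from precision and recall, then checks them against the reported test accuracy:\n",
"\n",
"```python\n",
"N_TEST = 1022  # rows in X_test\n",
"N_POS = 50  # inferred number of strokes in the test set\n",
"\n",
"# (precision_test, recall_test, accuracy_test) copied from the metrics table\n",
"reported = {\n",
"    \"mlp\": (0.200000, 0.020000, 0.948141),\n",
"    \"knn\": (0.117647, 0.120000, 0.912916),\n",
"    \"random_forest\": (0.135135, 0.500000, 0.818982),\n",
"}\n",
"\n",
"\n",
"def confusion_counts(precision, recall):\n",
"    # recall = TP / (TP + FN) and TP + FN = N_POS\n",
"    tp = round(recall * N_POS)\n",
"    # precision = TP / (TP + FP), so TP + FP = TP / precision\n",
"    predicted_pos = round(tp / precision)\n",
"    fp = predicted_pos - tp\n",
"    fn = N_POS - tp\n",
"    accuracy = (N_TEST - fp - fn) / N_TEST\n",
"    return tp, fp, fn, accuracy\n",
"\n",
"\n",
"counts = {name: confusion_counts(p, r) for name, (p, r, _) in reported.items()}\n",
"for name, (tp, fp, fn, acc) in counts.items():\n",
"    print(f\"{name}: TP={tp} FP={fp} FN={fn} accuracy={acc:.6f} (reported {reported[name][2]})\")\n",
"# mlp: TP=1 FP=4 FN=49 accuracy=0.948141 (reported 0.948141)\n",
"# knn: TP=6 FP=45 FN=44 accuracy=0.912916 (reported 0.912916)\n",
"# random_forest: TP=25 FP=160 FN=25 accuracy=0.818982 (reported 0.818982)\n",
"\n",
"# Accuracy of a model that always predicts \"no stroke\"\n",
"majority_accuracy = (N_TEST - N_POS) / N_TEST\n",
"print(f\"always no stroke: accuracy={majority_accuracy:.4f}\")  # 0.9511\n",
"```\n",
"\n",
"If the inference holds, the MLP's test accuracy (0.948) is actually slightly below that of a model that always predicts \"no stroke\" (0.951). This is why recall and F1 are the more informative metrics here."
]
},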
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression\n",
"\n",
"Split the dataset into training and test sets (80/20). The target feature is avg_glucose_level."
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>13276</th>\n",
" <td>Female</td>\n",
" <td>38.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>22.6</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21346</th>\n",
" <td>Female</td>\n",
" <td>12.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Rural</td>\n",
" <td>17.8</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>59178</th>\n",
" <td>Female</td>\n",
" <td>7.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Urban</td>\n",
" <td>22.3</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1679</th>\n",
" <td>Male</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>NaN</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1534</th>\n",
" <td>Female</td>\n",
" <td>61.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>26.1</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30463</th>\n",
" <td>Male</td>\n",
" <td>29.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>29.4</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>41935</th>\n",
" <td>Male</td>\n",
" <td>34.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>33.9</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68483</th>\n",
" <td>Female</td>\n",
" <td>60.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>41.2</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38617</th>\n",
" <td>Male</td>\n",
" <td>28.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>29.9</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46527</th>\n",
" <td>Male</td>\n",
" <td>53.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Rural</td>\n",
" <td>41.9</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"13276 Female 38.0 0 0 Yes Private \n",
"21346 Female 12.0 0 0 No children \n",
"59178 Female 7.0 0 0 No children \n",
"1679 Male 35.0 0 0 Yes Private \n",
"1534 Female 61.0 0 0 Yes Private \n",
"... ... ... ... ... ... ... \n",
"30463 Male 29.0 0 0 No Private \n",
"41935 Male 34.0 0 0 No Private \n",
"68483 Female 60.0 0 0 Yes Private \n",
"38617 Male 28.0 0 0 Yes Self-employed \n",
"46527 Male 53.0 1 1 Yes Govt_job \n",
"\n",
" Residence_type bmi smoking_status stroke \n",
"id \n",
"13276 Urban 22.6 Unknown 0 \n",
"21346 Rural 17.8 Unknown 0 \n",
"59178 Urban 22.3 Unknown 0 \n",
"1679 Rural NaN formerly smoked 0 \n",
"1534 Rural 26.1 smokes 0 \n",
"... ... ... ... ... \n",
"30463 Urban 29.4 formerly smoked 0 \n",
"41935 Rural 33.9 never smoked 0 \n",
"68483 Urban 41.2 formerly smoked 0 \n",
"38617 Urban 29.9 never smoked 0 \n",
"46527 Rural 41.9 never smoked 0 \n",
"\n",
"[4088 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"id\n",
"13276 71.06\n",
"21346 70.13\n",
"59178 86.75\n",
"1679 77.48\n",
"1534 99.35\n",
" ... \n",
"30463 82.93\n",
"41935 125.29\n",
"68483 65.38\n",
"38617 73.98\n",
"46527 109.51\n",
"Name: avg_glucose_level, Length: 4088, dtype: float64"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>8385</th>\n",
" <td>Male</td>\n",
" <td>37.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>35.9</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>937</th>\n",
" <td>Male</td>\n",
" <td>7.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Urban</td>\n",
" <td>NaN</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3494</th>\n",
" <td>Female</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>26.7</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23850</th>\n",
" <td>Male</td>\n",
" <td>66.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>33.1</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31156</th>\n",
" <td>Female</td>\n",
" <td>49.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>29.8</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>71010</th>\n",
" <td>Female</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>22.8</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>39518</th>\n",
" <td>Female</td>\n",
" <td>20.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>20.7</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7780</th>\n",
" <td>Male</td>\n",
" <td>51.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>30.7</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>56137</th>\n",
" <td>Female</td>\n",
" <td>62.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>36.3</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33175</th>\n",
" <td>Female</td>\n",
" <td>57.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>28.5</td>\n",
" <td>Unknown</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1022 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"8385 Male 37.0 0 0 Yes Private \n",
"937 Male 7.0 0 0 No children \n",
"3494 Female 80.0 0 0 Yes Private \n",
"23850 Male 66.0 0 0 Yes Private \n",
"31156 Female 49.0 0 0 Yes Private \n",
"... ... ... ... ... ... ... \n",
"71010 Female 80.0 0 0 No Self-employed \n",
"39518 Female 20.0 0 0 No Private \n",
"7780 Male 51.0 0 0 Yes Self-employed \n",
"56137 Female 62.0 0 0 Yes Private \n",
"33175 Female 57.0 0 0 Yes Govt_job \n",
"\n",
" Residence_type bmi smoking_status stroke \n",
"id \n",
"8385 Urban 35.9 Unknown 0 \n",
"937 Urban NaN Unknown 0 \n",
"3494 Rural 26.7 Unknown 0 \n",
"23850 Urban 33.1 never smoked 0 \n",
"31156 Urban 29.8 never smoked 0 \n",
"... ... ... ... ... \n",
"71010 Urban 22.8 never smoked 0 \n",
"39518 Rural 20.7 never smoked 0 \n",
"7780 Urban 30.7 never smoked 0 \n",
"56137 Urban 36.3 Unknown 0 \n",
"33175 Urban 28.5 Unknown 1 \n",
"\n",
"[1022 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"id\n",
"8385 90.78\n",
"937 87.94\n",
"3494 102.90\n",
"23850 103.01\n",
"31156 105.99\n",
" ... \n",
"71010 57.57\n",
"39518 78.94\n",
"7780 75.73\n",
"56137 88.32\n",
"33175 110.52\n",
"Name: avg_glucose_level, Length: 1022, dtype: float64"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"features = ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'bmi', 'smoking_status', 'stroke']\n",
"target = 'avg_glucose_level'\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=random_state)\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us choose a baseline for the regression task. For this we apply the zero rule (ZeroR): every test sample gets the same prediction, the mean of the target `avg_glucose_level` computed on the training set (`y_train`)."
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline RMSE: 44.12711275645952\n",
"Baseline RMAE: 5.662154850745081\n",
"Baseline R2: -0.0010729515309222393\n"
]
}
],
"source": [
"import math\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
"\n",
"# Baseline prediction: the mean of y_train for every test sample\n",
"baseline_predictions = [y_train.mean()] * len(y_test)\n",
"\n",
"# Quality metrics for the baseline\n",
"baseline_rmse = math.sqrt(\n",
" mean_squared_error(y_test, baseline_predictions)\n",
" )\n",
"# RMAE: square root of MAE (a non-standard metric, kept for comparison with the models below)\n",
"baseline_rmae = math.sqrt(\n",
" mean_absolute_error(y_test, baseline_predictions)\n",
" )\n",
"baseline_r2 = r2_score(y_test, baseline_predictions)\n",
"\n",
"print('Baseline RMSE:', baseline_rmse)\n",
"print('Baseline RMAE:', baseline_rmae)\n",
"print('Baseline R2:', baseline_r2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following metrics were used:\n",
"\n",
"- RMSE: the square root of MSE. MSE (Mean Squared Error) is the mean of the squared differences between predicted and true values. Because the differences are squared, MSE is sensitive to large errors. RMSE penalizes large errors in the same way, but unlike MSE it is expressed in the same units as the target variable, which makes it easier to interpret. This makes RMSE a good choice for many practical tasks where interpretability matters.\n",
"- RMAE: the square root of MAE. MAE (Mean Absolute Error) is the mean of the absolute differences between predictions and true values. MAE is less sensitive to outliers than MSE and RMSE, which makes it preferable when the data contain outliers that should not strongly affect the overall assessment of the model. RMAE itself is not a standard metric: taking the square root puts it in the square root of the target's units, so its values are not directly comparable with RMSE, and MAE is the usual choice.\n",
"- R2 (coefficient of determination): the proportion of the variance of the dependent variable that is explained by the model. It is a good way to assess how adequate a model is: values close to 1 mean the model explains the data well, 0 corresponds to always predicting the mean of the evaluated data, and negative values mean the model does worse than that constant. R2 is best suited for comparing models on the same data.\n",
"\n",
"The baseline values of these metrics show how much better (or worse) each model is than simply predicting the mean. The baseline R2 is slightly negative because the training mean used as the prediction differs slightly from the mean of the test set."
]
},
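{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small worked example of how the three error metrics react to a single large miss. The numbers are hypothetical and are not taken from the dataset: three predictions are off by 10 and one is off by 40. Because it squares the errors, RMSE gets pulled up by the outlier, while MAE grows only linearly. The square root of MAE (the RMAE used in this notebook) is no longer expressed in the units of the target.\n",
"\n",
"```python\n",
"import math\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error\n",
"\n",
"# Hypothetical true values and predictions; the last prediction is a large miss\n",
"y_true = [100.0, 100.0, 100.0, 100.0]\n",
"y_pred = [110.0, 90.0, 110.0, 140.0]\n",
"\n",
"rmse = math.sqrt(mean_squared_error(y_true, y_pred))  # sqrt(475), about 21.79\n",
"mae = mean_absolute_error(y_true, y_pred)  # (10 + 10 + 10 + 40) / 4 = 17.5\n",
"rmae = math.sqrt(mae)  # about 4.18, in square-root units of the target\n",
"\n",
"print(f'RMSE={rmse:.2f}, MAE={mae:.2f}, RMAE={rmae:.2f}')\n",
"```\n",
"\n",
"RMSE is always at least as large as MAE, and the two are equal only when every absolute error is the same. The gap between them therefore gives a rough signal of how much a few large errors dominate."
]
},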
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building the regression pipeline"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"columns_to_drop = []\n",
"columns_not_to_modify = [\"hypertension\", \"heart_disease\", \"stroke\", \"avg_glucose_level\"]\n",
"\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop\n",
" and column not in columns_not_to_modify\n",
" and df[column].dtype != \"object\"\n",
"]\n",
"\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop\n",
" and column not in columns_not_to_modify\n",
" and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
"        (\"preprocessing_num\", preprocessing_num, num_columns),\n",
"        (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end_reg = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
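{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of what the `ColumnTransformer` above does, using a hypothetical three-row frame. The column names mirror the dataset but the values are made up. Numeric columns are median-imputed and then standardized. Categorical columns are one-hot encoded with the first category dropped (`drop=\"first\"`). Columns that are not listed, such as `stroke`, pass through unchanged because of `remainder=\"passthrough\"`.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"# Hypothetical mini-frame: one numeric, one categorical and one passthrough column\n",
"toy = pd.DataFrame({\n",
"    'age': [30.0, 50.0, None],\n",
"    'Residence_type': ['Urban', 'Rural', 'Urban'],\n",
"    'stroke': [0, 1, 0],\n",
"})\n",
"\n",
"num = Pipeline([('imputer', SimpleImputer(strategy='median')), ('scaler', StandardScaler())])\n",
"cat = OneHotEncoder(sparse_output=False, drop='first')  # sparse_output needs scikit-learn >= 1.2\n",
"ct = ColumnTransformer(\n",
"    [('num', num, ['age']), ('cat', cat, ['Residence_type'])],\n",
"    remainder='passthrough',\n",
"    verbose_feature_names_out=False,\n",
")\n",
"out = ct.fit_transform(toy)\n",
"print(list(ct.get_feature_names_out()))  # ['age', 'Residence_type_Urban', 'stroke']\n",
"print(out)\n",
"```\n",
"\n",
"The missing age is filled with the training median (40), which becomes 0 after scaling. The `Rural` row is encoded as 0 in `Residence_type_Urban` because `Rural` is the dropped reference category."
]
},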
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's check the output of the pipeline on the training set:"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>bmi</th>\n",
" <th>gender_Male</th>\n",
" <th>gender_Other</th>\n",
" <th>ever_married_Yes</th>\n",
" <th>work_type_Never_worked</th>\n",
" <th>work_type_Private</th>\n",
" <th>work_type_Self-employed</th>\n",
" <th>work_type_children</th>\n",
" <th>Residence_type_Urban</th>\n",
" <th>smoking_status_formerly smoked</th>\n",
" <th>smoking_status_never smoked</th>\n",
" <th>smoking_status_smokes</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>13276</th>\n",
" <td>-0.236211</td>\n",
" <td>-0.826056</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21346</th>\n",
" <td>-1.386874</td>\n",
" <td>-1.455413</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>59178</th>\n",
" <td>-1.608155</td>\n",
" <td>-0.865391</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1679</th>\n",
" <td>-0.368980</td>\n",
" <td>-0.104918</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1534</th>\n",
" <td>0.781682</td>\n",
" <td>-0.367150</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30463</th>\n",
" <td>-0.634518</td>\n",
" <td>0.065532</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>41935</th>\n",
" <td>-0.413236</td>\n",
" <td>0.655554</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68483</th>\n",
" <td>0.737426</td>\n",
" <td>1.612701</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38617</th>\n",
" <td>-0.678774</td>\n",
" <td>0.131090</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46527</th>\n",
" <td>0.427632</td>\n",
" <td>1.704482</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 16 columns</p>\n",
"</div>"
],
"text/plain": [
" age bmi gender_Male gender_Other ever_married_Yes \\\n",
"id \n",
"13276 -0.236211 -0.826056 0.0 0.0 1.0 \n",
"21346 -1.386874 -1.455413 0.0 0.0 0.0 \n",
"59178 -1.608155 -0.865391 0.0 0.0 0.0 \n",
"1679 -0.368980 -0.104918 1.0 0.0 1.0 \n",
"1534 0.781682 -0.367150 0.0 0.0 1.0 \n",
"... ... ... ... ... ... \n",
"30463 -0.634518 0.065532 1.0 0.0 0.0 \n",
"41935 -0.413236 0.655554 1.0 0.0 0.0 \n",
"68483 0.737426 1.612701 0.0 0.0 1.0 \n",
"38617 -0.678774 0.131090 1.0 0.0 1.0 \n",
"46527 0.427632 1.704482 1.0 0.0 1.0 \n",
"\n",
" work_type_Never_worked work_type_Private work_type_Self-employed \\\n",
"id \n",
"13276 0.0 1.0 0.0 \n",
"21346 0.0 0.0 0.0 \n",
"59178 0.0 0.0 0.0 \n",
"1679 0.0 1.0 0.0 \n",
"1534 0.0 1.0 0.0 \n",
"... ... ... ... \n",
"30463 0.0 1.0 0.0 \n",
"41935 0.0 1.0 0.0 \n",
"68483 0.0 1.0 0.0 \n",
"38617 0.0 0.0 1.0 \n",
"46527 0.0 0.0 0.0 \n",
"\n",
" work_type_children Residence_type_Urban \\\n",
"id \n",
"13276 0.0 1.0 \n",
"21346 1.0 0.0 \n",
"59178 1.0 1.0 \n",
"1679 0.0 0.0 \n",
"1534 0.0 0.0 \n",
"... ... ... \n",
"30463 0.0 1.0 \n",
"41935 0.0 0.0 \n",
"68483 0.0 1.0 \n",
"38617 0.0 1.0 \n",
"46527 0.0 0.0 \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"id \n",
"13276 0.0 0.0 \n",
"21346 0.0 0.0 \n",
"59178 0.0 0.0 \n",
"1679 1.0 0.0 \n",
"1534 0.0 0.0 \n",
"... ... ... \n",
"30463 1.0 0.0 \n",
"41935 0.0 1.0 \n",
"68483 1.0 0.0 \n",
"38617 0.0 1.0 \n",
"46527 0.0 1.0 \n",
"\n",
" smoking_status_smokes hypertension heart_disease stroke \n",
"id \n",
"13276 0.0 0 0 0 \n",
"21346 0.0 0 0 0 \n",
"59178 0.0 0 0 0 \n",
"1679 0.0 0 0 0 \n",
"1534 1.0 0 0 0 \n",
"... ... ... ... ... \n",
"30463 0.0 0 0 0 \n",
"41935 0.0 0 0 0 \n",
"68483 0.0 0 0 0 \n",
"38617 0.0 0 0 0 \n",
"46527 0.0 1 1 0 \n",
"\n",
"[4088 rows x 16 columns]"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end_reg.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end_reg.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us tune the hyperparameters of each selected model with a grid search (`GridSearchCV`, which uses 5-fold cross-validation by default; here the score is the negative MSE) and collect the tuned models.\n",
"\n",
"knn -- k-nearest neighbors\n",
"\n",
"random_forest -- random forest (an ensemble of decision trees)\n",
"\n",
"mlp -- multilayer perceptron (a feed-forward neural network)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters for knn: {'n_jobs': -1, 'n_neighbors': 30, 'weights': 'uniform'}\n",
"Best parameters for random_forest: {'criterion': 'squared_error', 'max_depth': 7, 'max_features': 'sqrt', 'n_estimators': 250, 'n_jobs': -1, 'random_state': 9}\n",
"Best parameters for mlp: {'alpha': np.float64(1e-06), 'early_stopping': False, 'hidden_layer_sizes': np.int64(13), 'max_iter': 1000, 'random_state': 9, 'solver': 'adam'}\n"
]
}
],
"source": [
"# Candidate hyperparameter values for each model\n",
"param_grids = {\n",
"    \"knn\": {\n",
"        \"n_neighbors\": list(range(1, 31)),\n",
"        \"weights\": ['uniform', 'distance'],\n",
"        \"n_jobs\": [-1]\n",
"    },\n",
"    \"random_forest\": {\n",
"        \"n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"        \"max_features\": [\"sqrt\", \"log2\", 2],\n",
"        \"max_depth\": [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
"        \"criterion\": [\"squared_error\", \"absolute_error\", \"poisson\"],\n",
"        \"random_state\": [random_state],\n",
"        \"n_jobs\": [-1]\n",
"    },\n",
"    \"mlp\": {\n",
"        \"solver\": ['adam'],\n",
"        \"max_iter\": list(range(1000, 2001, 100)),\n",
"        \"alpha\": 10.0 ** -np.arange(1, 10),\n",
"        \"hidden_layer_sizes\": np.arange(10, 15),\n",
"        \"early_stopping\": [True, False],\n",
"        \"random_state\": [random_state]\n",
"    }\n",
"}\n",
"\n",
"# Model instances\n",
"models = {\n",
"    \"knn\": neighbors.KNeighborsRegressor(),\n",
"    \"random_forest\": ensemble.RandomForestRegressor(),\n",
"    \"mlp\": neural_network.MLPRegressor()\n",
"}\n",
"\n",
"# Models configured with their best parameters\n",
"class_models = {}\n",
"\n",
"# Grid search for each model\n",
"for model_name, model in models.items():\n",
"    gs_optimizer = GridSearchCV(\n",
"        estimator=model,\n",
"        param_grid=param_grids[model_name],\n",
"        scoring='neg_mean_squared_error',\n",
"        n_jobs=-1,\n",
"    )\n",
"\n",
"    # Fit the search on the preprocessed training set\n",
"    gs_optimizer.fit(preprocessed_df, y_train.values.ravel())\n",
"\n",
"    best_params = gs_optimizer.best_params_\n",
"    print(f\"Best parameters for {model_name}: {best_params}\")\n",
"\n",
"    class_models[model_name] = {\n",
"        \"model\": model.set_params(**best_params)  # configure the model with the best parameters\n",
"    }"
]
},
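{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the cell above, the preprocessing pipeline is fitted on the whole of `X_train` before `GridSearchCV` splits it into folds. As a result, the imputer medians and scaler statistics also see the validation folds. A common alternative is to tune a `Pipeline` that holds both the preprocessing and the model, so that preprocessing is refitted inside every fold. Parameters of a pipeline step are then addressed as `<step>__<parameter>`. The sketch below shows this pattern under two assumptions: it uses synthetic data from `make_regression` rather than the stroke dataset, and a single `StandardScaler` stands in for `pipeline_end_reg`.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_regression\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Synthetic stand-in data (not the stroke dataset)\n",
"X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=9)\n",
"\n",
"# Preprocessing and model in one pipeline, so scaling is refitted inside every CV fold\n",
"pipe = Pipeline([('scaler', StandardScaler()), ('model', KNeighborsRegressor())])\n",
"\n",
"# Parameters of the 'model' step get the 'model__' prefix\n",
"param_grid = {\n",
"    'model__n_neighbors': [5, 10, 20],\n",
"    'model__weights': ['uniform', 'distance'],\n",
"}\n",
"\n",
"search = GridSearchCV(pipe, param_grid, scoring='neg_mean_squared_error', cv=5)\n",
"search.fit(X, y)\n",
"print(search.best_params_)\n",
"\n",
"# With refit=True (the default) the best pipeline is refitted on all of X, y\n",
"best_pipeline = search.best_estimator_\n",
"```\n",
"\n",
"Because `search.best_estimator_` is already refitted on the full training data, it could replace the separate `set_params` and `fit` steps used in this notebook."
]
},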
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we train each tuned model inside the full pipeline and evaluate its quality on the training and test sets."
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" \n",
" model = class_models[model_name][\"model\"]\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end_reg), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_pred = model_pipeline.predict(X_train)\n",
" y_test_pred = model_pipeline.predict(X_test)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"train_preds\"] = y_train_pred\n",
" class_models[model_name][\"preds\"] = y_test_pred\n",
" \n",
" class_models[model_name][\"RMSE_train\"] = math.sqrt(\n",
" mean_squared_error(y_train, y_train_pred)\n",
" )\n",
" class_models[model_name][\"RMSE_test\"] = math.sqrt(\n",
" mean_squared_error(y_test, y_test_pred)\n",
" )\n",
" class_models[model_name][\"RMAE_test\"] = math.sqrt(\n",
" mean_absolute_error(y_test, y_test_pred)\n",
" )\n",
" class_models[model_name][\"R2_test\"] = r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RMSE, RMAE and R2 of each model, sorted by test RMSE:"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_0650d_row0_col0, #T_0650d_row2_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_0650d_row0_col1, #T_0650d_row1_col0 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0650d_row0_col2, #T_0650d_row2_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0650d_row0_col3, #T_0650d_row2_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0650d_row1_col1 {\n",
" background-color: #20938c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0650d_row1_col2 {\n",
" background-color: #b42e8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0650d_row1_col3 {\n",
" background-color: #c8437b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0650d_row2_col0 {\n",
" background-color: #73d056;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_0650d\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_0650d_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_0650d_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_0650d_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_0650d_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_0650d_level0_row0\" class=\"row_heading level0 row0\" >mlp</th>\n",
" <td id=\"T_0650d_row0_col0\" class=\"data row0 col0\" >42.583378</td>\n",
" <td id=\"T_0650d_row0_col1\" class=\"data row0 col1\" >40.922194</td>\n",
" <td id=\"T_0650d_row0_col2\" class=\"data row0 col2\" >5.533579</td>\n",
" <td id=\"T_0650d_row0_col3\" class=\"data row0 col3\" >0.139061</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0650d_level0_row1\" class=\"row_heading level0 row1\" >random_forest</th>\n",
" <td id=\"T_0650d_row1_col0\" class=\"data row1 col0\" >40.324186</td>\n",
" <td id=\"T_0650d_row1_col1\" class=\"data row1 col1\" >41.085298</td>\n",
" <td id=\"T_0650d_row1_col2\" class=\"data row1 col2\" >5.544678</td>\n",
" <td id=\"T_0650d_row1_col3\" class=\"data row1 col3\" >0.132184</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0650d_level0_row2\" class=\"row_heading level0 row2\" >knn</th>\n",
" <td id=\"T_0650d_row2_col0\" class=\"data row2 col0\" >42.164413</td>\n",
" <td id=\"T_0650d_row2_col1\" class=\"data row2 col1\" >41.826505</td>\n",
" <td id=\"T_0650d_row2_col2\" class=\"data row2 col2\" >5.550755</td>\n",
" <td id=\"T_0650d_row2_col3\" class=\"data row2 col3\" >0.100590</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x23c2371daf0>"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The results as plots:"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: knn\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIjCAYAAAAJLyrXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3xT9foH8M852UnTvRhllb0EBXHjYCgKKo6Lyu+K4uIWEERFFEXFq4jIEHBeRa9XXNd1RUUQEVRQUEDBsnfpXtn7nN8foaHpgKZNmo7P+/XyJT1Jkydpcs73+Y7nK8iyLIOIiIiIiIjqTIx2AERERERERM0NEykiIiIiIqIQMZEiIiIiIiIKERMpIiIiIiKiEDGRIiIiIiIiChETKSIiIiIiohAxkSIiIiIiIgoREykiIiIiIqIQMZEiIiIiIiIKERMpIiJqcQRBwJNPPhny7x05cgSCIODtt9+u1/NOmDABMTEx9fpdIiJqXphIERFRRLz99tsQBAGCIOCnn36qdrssy8jIyIAgCLjmmmuiECEREVH9MZEiIqKI0mq1WLlyZbXjGzZsQE5ODjQaTRSiIiIiahgmUkREFFGjRo3Cxx9/DK/XG3R85cqVOOecc5Cenh6lyIiIiOqPiRQREUXULbfcgpKSEqxduzZwzO1247///S9uvfXWGn/HZrNhxowZyMjIgEajQY8ePbBgwQLIshx0P5fLhenTpyMlJQVGoxFjxoxBTk5OjY954sQJ3HnnnUhLS4NGo0GfPn3w1ltvnTF+j8eDPXv2IC8vL4RXfcqOHTuQkpKCSy+9FFarFQDQqVMnXHPNNfjpp59w7rnnQqvVokuXLvj3v/8d9LsV0yN//vlnPPDAA0hJSYHBYMD111+PoqKiesVDREThwUSKiIgiqlOnTjj//PPx/vvvB4598803MJlMGDduXLX7y7KMMWPGYNGiRbjyyiuxcOFC9OjRAw899BAeeOCBoPveddddWLx4MUaMGIF58+ZBpVLh6quvrvaYBQUFOO+88/Ddd99h8uTJWLJkCbp27YqJEydi8eLFp43/xIkT6NWrF2bNmhXya9+6dSsuv/xyDBw4EN98801QIYoDBw7gxhtvxPDhw/Hiiy8iISEBEyZMwF9//VXtcaZMmYI//vgDc+bMwaRJk/Dll19i8uTJIcdDRETho4x2AERE1PLdeuutmDVrFhwOB3Q6Hd577z0MHToUbdu2rXbf//3vf/j+++/xzDPP4LHHHgMAZGVl4aabbsKSJUswefJkZGZm4o8//sB//vMf/OMf/8Dy5csD97vtttvw559/Bj3mY489Bp/Ph507dyIpKQkAcN999+GWW27Bk08+iXvvvRc6nS6sr/nnn3/GqFGjcPHFF+OTTz6pthZs79692LhxIy6++GIAwM0334yMjAysWLECCxYsCLpvUlIS1qxZA0EQAACSJOGll16CyWRCXFxcWOMmIqK64YgUERFF3M033wyHw4FVq1bBYrFg1apVtU7r+/rrr6FQKDB16tSg4zNmzIAsy/jmm28C9wNQ7X7Tpk0L+lmWZXzyyScYPXo0ZFlGcXFx4L+RI0fCZDJh27ZttcbeqVMnyLIcUkn09evXY+TIkbjiiivw6aef1lhQo3fv3oEkCgBSUlLQo0cPHDp0qNp977nnnkASBQAXX3wxfD4fjh49WueYiIgovDgiRUREEZeSkoJhw4Zh5cqVsNvt8Pl8uPHGG2u879GjR9G2bVsYjcag47169QrcXvF/URSRmZkZdL8ePXoE/VxUVITy8nK8/vrreP3112t8zsLCwnq9rpo4nU5cffXVOOecc/DRRx9Bqaz5UtuhQ4dqxxISElBWVnbG+yYkJABAjfclIqLGwUSKiIgaxa233oq7774b+fn5uOqqqxAfH98ozytJEgBg/PjxuP3222u8T//+/cP2fBqNBqNGjcIXX3yB1atX17pHlkKhqPF41YIaod6XiIgaBx
MpIiJqFNdffz3uvfde/PLLL/jwww9rvV/Hjh3x3XffwWKxBI1K7dmzJ3B7xf8lScLBgweDRqH27t0b9HgVFf18Ph+GDRsWzpdUI0EQ8N577+Haa6/FTTfdhG+++QaXXnppxJ+XiIgaF9dIERFRo4iJicErr7yCJ598EqNHj671fqNGjYLP58OyZcuCji9atAiCIOCqq64CgMD/X3rppaD7Va3Cp1AocMMNN+CTTz7Brl27qj3fmcqI16f8uVqtxqefforBgwdj9OjR2LJlS51/l4iImgeOSBERUaOpbWpdZaNHj8Zll12Gxx57DEeOHMFZZ52FNWvW4IsvvsC0adMCa6IGDBiAW265BS+//DJMJhMuuOACrFu3DgcOHKj2mPPmzcP69esxZMgQ3H333ejduzdKS0uxbds2fPfddygtLa01nory57fffntIBSd0Oh1WrVqFyy+/HFdddRU2bNiAvn371vn3iYioaeOIFBERNSmiKOJ///sfpk2bhlWrVmHatGnIzs7GCy+8gIULFwbd96233sLUqVOxevVqPPzww/B4PPjqq6+qPWZaWhq2bNmCO+64A59++mlgL6nS0lI8//zzEXstsbGx+Pbbb5Geno7hw4fXmOQREVHzJMhcqUpERERERBQSjkgRERERERGFiIkUERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREVGIuCEvAEmSkJubC6PRCEEQoh0OERERERFFiSzLsFgsaNu2LUSx9nEnJlIAcnNzkZGREe0wiIiIiIioiTh+/Djat29f6+1MpAAYjUYA/jcrNjY2ytEQEREREVG0mM1mZGRkBHKE2jCRAgLT+WJjY5lIERERERHRGZf8sNgEERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREVGImEgRERERERGFiIkUERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREVGImEgRERERERGFiIkUERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREUWFLMs4cOBAtMOoFyZSRERERETU6Hbt2oXLLrsMgwYNQmFhYbTDCRkTKSIiIiIiajQmkwnTpk3DgAEDsGHDBphMJsyaNSvaYYVMGe0AiIiIiIio5ZMkCe+++y4efvjhoBGozMxMjB07NoqR1Q8TKSIiIiIiiqgdO3YgKysLmzZtChzT6XR49NFH8eCDD0Kr1UYxuvphIkVERERERBFhMpnw6KOP4tVXX4UkSYHjY8eOxcKFC9GxY8coRtcwTKSIiIiIiCgiZFnGf//730AS1b17dyxduhQjRoyIcmQNx2ITREREREQUEfHx8Zg/fz4MBgPmzZuHnTt3togkCmAiRUREREREYVBSUoIpU6bgxIkTQcf/7//+D/v378fMmTOhVqujFF34cWofERERERHVm8/nwxtvvIHHHnsMpaWlKCkpwcqVKwO3i6KINm3aRDHCyOCIFBERERER1cvmzZtx7rnnYtKkSSgtLQUArFq1Cnl5eVGOLPKYSBERERERUUgKCwtx55134oILLsC2bdsCx8ePH4+9e/e2yBGoqji1j4iIiIiI6sTr9eKVV17B448/DpPJFDjer18/LF++HBdffHEUo2tcTKSIiIiIiKhOrrvuOnz11VeBn+Pi4jB37lxMmjQJSmXrSi04tY+IiIiIiOrk73//e+DfEyZMwN69ezFlypRWl0QBHJEiIiIiIqIaeDweWCwWJCYmBo7ddNNN+Omnn3DLLbfg/PPPj2J00ccRKSIiIiIiCvLDDz9g4MCBmDhxYtBxQRDw0ksvtfokCmAiRUREREREJ+Xk5OCWW27BZZddhr/++guff/45Vq9eHe2wmiQmUkRERERErZzb7cb8+fPRs2dPfPDBB4HjgwcPRmpqahQja7q4RoqIiIiIqBVbu3
YtpkyZgr179waOJSUlYd68ebjzzjshihx7qQnfFSIiIiKiVujYsWO44YYbMGLEiEASJYoi/vGPf2Dfvn246667mES
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: random_forest\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIjCAYAAAAJLyrXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hT9f4H8Pc5WU3TNN0to6VQ9kZRRFQQGYKCigsUBcGFDEFUhB8KilcBkSG4ryLXK67ruqIyRMR7BS8ooGDZG0p3m6bZ4/z+CD02dNC0aZO079fz8NB+c5p8cnJyzvl8pyBJkgQiIiIiIiKqMTHYARAREREREYUbJlJERERERER+YiJFRERERETkJyZSREREREREfmIiRURERERE5CcmUkRERERERH5iIkVEREREROQnJlJERERERER+YiJFRERERETkJyZSRERULUEQMH/+fL//7sSJExAEAe+9917AY2oo4fAeXnrpJbRp0wYKhQI9e/YMdjhERE0GEykiojDw3nvvQRAECIKA//73vxUelyQJqampEAQBN954YxAipGDYuHEjnnzySfTr1w+rV6/GCy+8EOyQKsjMzMT8+fNx4sSJYIdCRBRQymAHQERENRcREYG1a9fiqquu8infunUrzpw5A41GE6TIKBh++OEHiKKId955B2q1OtjhVCozMxPPPvssBgwYgPT09GCHQ0QUMGyRIiIKI8OHD8enn34Kl8vlU7527VpceumlSElJCVJkDcNsNgc7hJCSm5sLrVYbsCRKkiRYrdaAPBcRUWPHRIqIKIyMGTMGBQUF2LRpk1zmcDjwr3/9C3fddVelf2M2mzFz5kykpqZCo9GgQ4cOWLJkCSRJ8tnObrdjxowZSExMhF6vx8iRI3HmzJlKn/Ps2bOYMGECkpOTodFo0KVLF7z77rsXjd/pdOLAgQM4d+7cRbcdP348oqKicPToUQwfPhx6vR533303AOA///kPbr/9dqSlpUGj0SA1NRUzZsyokASUPcfZs2dx8803IyoqComJiXj88cfhdrt9ti0uLsb48eNhMBgQExODcePGobi4uNLYfvjhB1x99dXQ6XSIiYnBTTfdhP379/tsM3/+fAiCgEOHDmHs2LEwGAxITEzE008/DUmScPr0adx0002Ijo5GSkoKXn755Yvuk/IEQcDq1athNpvlbp9lY7lcLhcWLFiAjIwMaDQapKenY86cObDb7T7PkZ6ejhtvvBEbNmxA7969odVq8eabb8r7Y/r06fJx07ZtWyxatAgej8fnOT766CNceuml0Ov1iI6ORrdu3bBixQoA3i6pt99+OwDg2muvleP88ccf/XqvREShiIkUEVEYSU9PR9++ffHhhx/KZd999x2MRiNGjx5dYXtJkjBy5EgsW7YM119/PZYuXYoOHTrgiSeewGOPPeaz7f3334/ly5djyJAhWLhwIVQqFW644YYKz5mTk4MrrrgC33//PaZMmYIVK1agbdu2mDhxIpYvX15t/GfPnkWnTp0we/bsGr1fl8uFoUOHIikpCUuWLMGtt94KAPj0009hsVgwadIkrFy5EkOHDsXKlStx7733VngOt9uNoUOHIj4+HkuWLEH//v3x8ssv46233vLZTzfddBPef/99jB07Fs8//zzOnDmDcePGVXi+77//HkOHDkVubi7mz5+Pxx57DNu2bUO/fv0qHQd05513wuPxYOHChejTpw+ef/55LF++HIMHD0aLFi2waNEitG3bFo8//jh++umnGu0XAHj//fdx9dVXQ6PR4P3338f777+Pa665BoD3s3zmmWdwySWXYNmyZejfvz9efPHFSo+RgwcPYsyYMRg8eDBWrFiBnj17wmKxoH///vjnP/+Je++9F6+88gr69euH2bNn+xw3mzZtwpgxYxAbG4tFixZh4cKFGDBgAH7++WcAwDXXXINp06YBAObMmSPH2alTpxq/TyKikCUREVHIW716tQ
RA2rlzp7Rq1SpJr9dLFotFkiRJuv3226Vrr71WkiRJatWqlXTDDTfIf/fll19KAKTnn3/e5/luu+02SRAE6ciRI5IkSdKePXskANIjjzzis91dd90lAZDmzZsnl02cOFFq1qyZlJ+f77Pt6NGjJYPBIMd1/PhxCYC0evVqeZuysnHjxl30PY8bN04CID311FMVHit7jfJefPFFSRAE6eTJkxWe47nnnvPZtlevXtKll14q/162nxYvXiyXuVwu6eqrr67wHnr27CklJSVJBQUFctnvv/8uiaIo3XvvvXLZvHnzJADSgw8+6POcLVu2lARBkBYuXCiXFxUVSVqttkb7pbxx48ZJOp3Op6zss7z//vt9yh9//HEJgPTDDz/IZa1atZIASOvXr/fZdsGCBZJOp5MOHTrkU/7UU09JCoVCOnXqlCRJkvToo49K0dHRksvlqjLGTz/9VAIgbdmyxa/3RkQU6tgiRUQUZu644w5YrVasW7cOJpMJ69atq7Jb37fffguFQiG3CpSZOXMmJEnCd999J28HoMJ206dP9/ldkiR89tlnGDFiBCRJQn5+vvxv6NChMBqN2LVrV5Wxp6enQ5Ikv6YTnzRpUoUyrVYr/2w2m5Gfn48rr7wSkiRh9+7dFbZ/+OGHfX6/+uqrcezYMfn3b7/9Fkql0ue1FAoFpk6d6vN3586dw549ezB+/HjExcXJ5d27d8fgwYPl/Vje/fff7/OcvXv3hiRJmDhxolweExODDh06+MRUW2UxXNjiOHPmTADAN99841PeunVrDB061Kfs008/xdVXX43Y2Fifz3jQoEFwu91yy1lMTAzMZrNPV1MioqaCs/YREYWZxMREDBo0CGvXroXFYoHb7cZtt91W6bYnT55E8+bNodfrfcrLuladPHlS/l8URWRkZPhs16FDB5/f8/LyUFxcjLfeesuna1x5ubm5tXpflVEqlWjZsmWF8lOnTuGZZ57Bv//9bxQVFfk8ZjQafX6PiIhAYmKiT1lsbKzP3508eRLNmjVDVFSUz3YXvv+y/XVhOeDdpxs2bIDZbIZOp5PL09LSfLYzGAyIiIhAQkJChfKCgoIKz+uvss+ybdu2PuUpKSmIiYmR30OZ1q1bV3iOw4cP448//qiw38qUfcaPPPIIPvnkEwwbNgwtWrTAkCFDcMcdd+D666+v8/sgIgp1TKSIiMLQXXfdhQceeADZ2dkYNmwYYmJiGuR1yyYaGDt2bKXjhwBv60ygaDQaiKJv5wm3243BgwejsLAQs2bNQseOHaHT6XD27FmMHz++wmQICoUiYPHURmWvX1VM0gUTgNSFIAg12q58614Zj8eDwYMH48knn6z0b9q3bw8ASEpKwp49e7BhwwZ89913+O6777B69Wrce++9WLNmTe2DJyIKA0ykiIjC0C233IKHHnoIv/zyCz7++OMqt2vVqhW+//57mEwmn1apAwcOyI+X/e/xeHD06FGf1paDBw/6PF/ZjH5utxuDBg0K5Fuqsb179+LQoUNYs2aNz+QSdele1qpVK2zevBmlpaU+rVIXvv+y/XVhOeDdpwkJCT6tUcFQ9lkePnzYZ1KHnJwcFBcXy++hOhkZGSgtLa3RZ6xWqzFixAiMGDECHo8HjzzyCN588008/fTTaNu2bY0TOiKicMMxUkREYSgqKgqvv/465s+fjxEjRlS53fDhw+F2u7Fq1Sqf8mXLlkEQBAwbNgwA5P9feeUVn+0unIVPoVDg1ltvxWeffYZ9+/ZVeL28vLxq4/Zn+vOqlLXmlG+9kSRJnnK7NoYPHw6Xy4XXX39dLnO73Vi5cqXPds2aNUPPnj2xZs0an6nR9+3bh40bN2L48OG1jiFQymK48LNbunQpAFQ6E+OF7rjjDmzfvh0bNmyo8FhxcbG8jtmFXRFFUZRbJMumWi9LLKuaSp6IKFyxRYqIKExV1bWuvBEjRuDaa6/F//3f/+HEiRPo0aMHNm7ciK+++grTp0+Xx0T17N
kTY8aMwWuvvQaj0Ygrr7wSmzdvxpEjRyo858KFC7Flyxb06dMHDzzwADp37ozCwkLs2rUL33//PQoLC6uMp2z683H
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: mlp\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIjCAYAAAAJLyrXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hUZdoG8Puc6S2dJJRAIHQBQUFUXLGgKAqr2FBZC9gwgCAq4qrYVgGpAtZVdHdFXVdXP3FtiwgWXFFAQXqTkB6STG9nzvn+mGRkSAKZZCYzSe7fdXlJ3pnMPDOZOec8b3leQVEUBURERERERNRoYrwDICIiIiIiam2YSBEREREREUWIiRQREREREVGEmEgRERERERFFiIkUERERERFRhJhIERERERERRYiJFBERERERUYSYSBEREREREUWIiRQREREREVGEmEgREVGbIwgCHnvssYh/79ChQxAEAa+//nrUY2qMr776CoIg4KuvvorL8xMRUeMxkSIioph4/fXXIQgCBEHAN998U+d2RVGQk5MDQRBw+eWXxyFCIiKipmMiRUREMaXX67F69eo67evXr8eRI0eg0+niEBUREVHzMJEiIqKYGjNmDN59911IkhTWvnr1apx++unIzs6OU2RERERNx0SKiIhi6vrrr8fRo0fxxRdfhNp8Ph/+9a9/4YYbbqj3d5xOJ2bNmoWcnBzodDr06dMHCxcuhKIoYffzer2YOXMmOnToAIvFgnHjxuHIkSP1PmZhYSEmTZqErKws6HQ6nHLKKXjttddOGr/f78euXbtQXFx80vvecsstMJvNOHz4MC6//HKYzWZ07twZK1euBABs27YNF1xwAUwmE7p161bvSN3xzjvvPAwYMAA//fQTzj77bBgMBnTv3h0vvvjiSX+XiIhih4kUERHFVG5uLs466yy89dZbobZPPvkEVqsVEyZMqHN/RVEwbtw4LFmyBJdccgkWL16MPn364P7778e9994bdt/bbrsNS5cuxcUXX4x58+ZBo9Hgsssuq/OYpaWlOPPMM/Hf//4XU6dOxbJly9CzZ09MnjwZS5cuPWH8hYWF6NevH+bMmdOo1xsIBHDppZciJycHCxYsQG5uLqZOnYrXX38dl1xyCYYOHYr58+fDYrHgpptuwsGDB0/6mFVVVRgzZgxOP/10LFiwAF26dMGUKVMalQgSEVGMKERERDGwatUqBYCyadMmZcWKFYrFYlFcLpeiKIpyzTXXKOeff76iKIrSrVs35bLLLgv93gcffKAAUJ566qmwx7v66qsVQRCUffv2KYqiKFu3blUAKHfffXfY/W644QYFgDJ37txQ2+TJk5WOHTsqFRUVYfedMGGCkpycHIrr4MGDCgBl1apVofvUtt18880nfc0333yzAkB5+umnQ21VVVWKwWBQBEFQ3n777VD7rl276sS5bt06BYCybt26UNvIkSMVAMqiRYtCbV6vVxk8eLCSmZmp+Hy+k8ZFRETRxxEpIiKKuWuvvRZutxtr1qyB3W7HmjVrGpzW95///AcqlQrTp08Pa581axYURcEnn3wSuh+AOvebMWNG2M+KouC9997D2LFjoSgKKioqQv+NHj0aVqsVmzdvbjD23NxcKIoSUUn02267LfTvlJQU9OnTByaTCddee22ovU+fPkhJScGBAwdO+nhqtRp33nln6GetVos777wTZWVl+OmnnxodFxERRY863gEQEVHb16FDB4waNQqrV6+Gy+VCIBDA1VdfXe99f/vtN3Tq1AkWiyWsvV+/fqHba/8viiLy8vLC7tenT5+wn8vLy1FdXY2XX34ZL7/8cr3PWVZW1qTXVR+9Xo8OHTqEtSUnJ6NLly4QBKFOe1VV1Ukfs1OnTjCZTGFtvXv3BhDc++rMM89sZtRERBQpJlJERNQibrjhBtx+++0oKSnBpZdeipSUlBZ5XlmWAQATJ07EzTffXO99Bg0aFLXnU6lUEbUrxx
XQICKi1oGJFBERtYgrr7wSd955J77//nu88847Dd6vW7du+O9//wu73R42KrVr167Q7bX/l2UZ+/fvDxuF2r17d9jj1Vb0CwQCGDVqVDRfUospKiqC0+kMG5Xas2cPgODUQyIianlcI0VERC3CbDbjhRdewGOPPYaxY8c2eL8xY8YgEAhgxYoVYe1LliyBIAi49NJLASD0/+eeey7sfsdX4VOpVLjqqqvw3nvvYfv27XWer7y8/IRxR1L+PFYkScJLL70U+tnn8+Gll15Chw4dcPrpp8ctLiKi9owjUkRE1GIamlp3rLFjx+L888/Hn//8Zxw6dAinnnoqPv/8c3z44YeYMWNGaE3U4MGDcf311+P555+H1WrF2WefjbVr12Lfvn11HnPevHlYt24dhg8fjttvvx39+/dHZWUlNm/ejP/+97+orKxsMJ7a8uc333xzRAUnoqlTp06YP38+Dh06hN69e+Odd97B1q1b8fLLL0Oj0cQlJiKi9o6JFBERJRRRFPF///d/ePTRR/HOO+9g1apVyM3NxbPPPotZs2aF3fe1115Dhw4d8Oabb+KDDz7ABRdcgI8//hg5OTlh98vKysIPP/yAJ554Au+//z6ef/55pKen45RTTsH8+fNb8uU1SWpqKt544w1MmzYNr7zyCrKysrBixQrcfvvt8Q6NiKjdEhSuciUiIkpY5513HioqKuqdlkhERPHDNVJEREREREQRYiJFREREREQUISZSREREREREEeIaKSIiIiIioghxRIqIiIiIiChCTKSIiIiIiIgixH2kAMiyjKKiIlgsFgiCEO9wiIiIiIgoThRFgd1uR6dOnSCKDY87MZECUFRUVGfzRiIiIiIiar8KCgrQpUuXBm9nIgXAYrEACL5ZSUlJcY6GiIiIiIjixWazIScnJ5QjNISJFBCazpeUlMREioiIiIiITrrkh8UmiIiIiIiIIsREioiIiIiIKEJMpIiIiIiIiCLERIqIiIiIiChCTKSIiIiIiIgixESKiIiIiIgoQkykiIiIiIiIIsREioiIiIiIKEJMpIiIiIiIiCLERIqIiIiIiChCTKSIiIiIiIgixESKiIiIiIgoQkykiIiIiIiIIsREioiIiIiIKEJMpIiIiIiIiCLERIqIiIiIiChCTKSIiIiIiCguFEXBvn374h1GkzCRIiIiIiKiFrd9+3acf/75GDp0KMrKyuIdTsSYSBERERERUYuxWq2YMWMGBg8ejPXr18NqtWLOnDnxDiti6ngHQEREREREbZ8sy/j73/+OBx54IGwEKi8vD+PHj49jZE3DRIqIiIiIiGJq69atyM/Px3fffRdqMxgMeOihh3DfffdBr9fHMbqmYSJFREREREQxYbVa8dBDD+HFF1+ELMuh9vHjx2Px4sXo1q1bHKNrHiZSREREREQUE4qi4F//+lcoierduzeWL1+Oiy++OM6RNR+LTRARERERUUykpKRgwYIFMJlMmDdvHrZt29YmkiiAiRQREREREUXB0aNHMW3aNBQWFoa1/+lPf8LevXsxe/ZsaLXaOEUXfZzaR0RERERETRYIBPDKK6/gz3/+MyorK3H06FGsXr06dLsoiujYsWMcI4wNjkgREREREVGTbNy4EWeccQamTJmCyspKAMCaNWtQXFwc58hij4kUERERERFFpKysDJMmTcLZZ5+NzZs3h9onTpyI3bt3t8kRqONxah8RERERETWKJEl44YUX8Mgjj8BqtYbaBw4ciJUrV+IPf/hDHKNrWUykiIiIiIioUa644gp8/PHHoZ+Tk5Px5JNPYsqUKVCr21dqwal9RERERETUKDfddFPo37fccgt2796NadOmtbskCuCIFBERERER1cPv98NutyMtLS3Uds011+Cbb77B9ddfj7POOiuO0cUfR6SIiIiIiCjMV199hSFDhmDy5Mlh7YIg4Lnnnmv3SRTARIqIiIiIiGocOXIE119/Pc4//3z8+uuv+OCDD/Dpp5/GO6
yExESKiIiIiKid8/l8WLBgAfr27Yu333471D5s2DBkZmbGMbLExTVSRERERETt2BdffIFp06Zh9+7dobb09HTMmzc
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot predicted vs. actual glucose level for every model\n",
"for model_name, model_data in class_models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    y_pred = model_data[\"preds\"]\n",
"    plt.figure(figsize=(10, 6))\n",
"    plt.scatter(y_test, y_pred, alpha=0.5)\n",
"    # Dashed diagonal y = x: points lying on it are perfect predictions\n",
"    plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)\n",
"    plt.xlabel('Actual glucose level')\n",
"    plt.ylabel('Predicted glucose level')\n",
"    plt.title(f\"Model: {model_name}\")\n",
"    plt.show()"
]
},
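  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below is not part of the original analysis. It is a minimal illustration, under stated assumptions, of why a constant baseline has a non-positive $R^2$ on the test set, while a model that explains part of the variance can reach a positive one. It assumes the baseline predicts the mean $c$ of the training target, although the notebook's actual baseline may differ. The helpers `rmse_of` and `compare_with_mean_baseline` are hypothetical names, and the demo arrays are synthetic. For a constant prediction $c$, $SS_{res} = SS_{tot} + n(\\bar{y} - c)^2$, so $R^2 = -n(\\bar{y} - c)^2 / SS_{tot} \\le 0$, where $\\bar{y}$ is the mean of the test target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch: compare a regression model with a constant mean baseline\n",
    "import numpy as np\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "\n",
    "\n",
    "def rmse_of(y_true, y_pred):\n",
    "    # RMSE in the units of the target (glucose level)\n",
    "    return float(np.sqrt(mean_squared_error(y_true, y_pred)))\n",
    "\n",
    "\n",
    "def compare_with_mean_baseline(y_train, y_test, y_pred):\n",
    "    # Assumed baseline: predict the training-set mean c for every test sample;\n",
    "    # its test R2 equals -n * (mean(y_test) - c)**2 / SS_tot, so it is never positive\n",
    "    baseline = np.full(len(y_test), np.mean(y_train))\n",
    "    return {\n",
    "        'rmse_baseline': rmse_of(y_test, baseline),\n",
    "        'rmse_model': rmse_of(y_test, y_pred),\n",
    "        'r2_baseline': float(r2_score(y_test, baseline)),\n",
    "        'r2_model': float(r2_score(y_test, y_pred)),\n",
    "    }\n",
    "\n",
    "\n",
    "# Synthetic demo data (illustrative only, not the stroke dataset);\n",
    "# the demo_ prefix avoids overwriting the notebook's y_train / y_test\n",
    "demo_rng = np.random.default_rng(0)\n",
    "demo_train = demo_rng.normal(100, 40, size=200)\n",
    "demo_test = demo_rng.normal(105, 40, size=50)\n",
    "demo_pred = demo_test + demo_rng.normal(0, 30, size=50)\n",
    "print(compare_with_mean_baseline(demo_train, demo_test, demo_pred))\n",
    "\n",
    "# Assumed usage on this notebook's models (requires a y_train for the same target):\n",
    "# for name, data in class_models.items():\n",
    "#     print(name, compare_with_mean_baseline(y_train, y_test, data['preds']))"
   ]
  },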
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plots show that, overall, none of the models achieves high quality. The predictions are widely scattered around the ideal line y = x, which indicates large deviations of the predicted values from the actual ones.\n",
"\n",
"Nevertheless, every model outperforms the baseline on all metrics, though perhaps not by much. The most noticeable improvement is in R², which goes from a negative value for the baseline to a positive one for the models. This means the models explain at least part of the variance in the data.\n",
"\n",
"In addition, all models have moderate variance and show little sign of overfitting, because the difference between the training and test RMSE is small.\n",
"\n",
"Final conclusions:\n",
"- Best model: MLP. It has the lowest RMSE and the highest R², which indicates the best accuracy and the largest share of explained variance in the target variable.\n",
"\n",
"- Random Forest: close to MLP in performance, with a slightly higher RMSE. It is, however, the more stable model, with only small differences between the training and test results.\n",
"\n",
"- KNN: the weakest model, with the largest errors and a low R². It needs further improvement, or another model should be used for this task."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}