2024-12-21 04:54:14 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stroke data\n",
"\n",
"Load the dataset (indexed by patient `id`) and display it:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>9046</th>\n",
" <td>Male</td>\n",
" <td>67.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>228.69</td>\n",
" <td>36.6</td>\n",
" <td>formerly smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>51676</th>\n",
" <td>Female</td>\n",
" <td>61.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>202.21</td>\n",
" <td>NaN</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31112</th>\n",
" <td>Male</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>105.92</td>\n",
" <td>32.5</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>60182</th>\n",
" <td>Female</td>\n",
" <td>49.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>171.23</td>\n",
" <td>34.4</td>\n",
" <td>smokes</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1665</th>\n",
" <td>Female</td>\n",
" <td>79.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>174.12</td>\n",
" <td>24.0</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18234</th>\n",
" <td>Female</td>\n",
" <td>80.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>83.75</td>\n",
" <td>NaN</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>44873</th>\n",
" <td>Female</td>\n",
" <td>81.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>125.20</td>\n",
" <td>40.0</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19723</th>\n",
" <td>Female</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>82.99</td>\n",
" <td>30.6</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37544</th>\n",
" <td>Male</td>\n",
" <td>51.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>166.29</td>\n",
" <td>25.6</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>44679</th>\n",
" <td>Female</td>\n",
" <td>44.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>85.28</td>\n",
" <td>26.2</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5110 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"9046 Male 67.0 0 1 Yes Private \n",
"51676 Female 61.0 0 0 Yes Self-employed \n",
"31112 Male 80.0 0 1 Yes Private \n",
"60182 Female 49.0 0 0 Yes Private \n",
"1665 Female 79.0 1 0 Yes Self-employed \n",
"... ... ... ... ... ... ... \n",
"18234 Female 80.0 1 0 Yes Private \n",
"44873 Female 81.0 0 0 Yes Self-employed \n",
"19723 Female 35.0 0 0 Yes Self-employed \n",
"37544 Male 51.0 0 0 Yes Private \n",
"44679 Female 44.0 0 0 Yes Govt_job \n",
"\n",
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
"id \n",
"9046 Urban 228.69 36.6 formerly smoked 1 \n",
"51676 Rural 202.21 NaN never smoked 1 \n",
"31112 Rural 105.92 32.5 never smoked 1 \n",
"60182 Urban 171.23 34.4 smokes 1 \n",
"1665 Rural 174.12 24.0 never smoked 1 \n",
"... ... ... ... ... ... \n",
"18234 Urban 83.75 NaN never smoked 0 \n",
"44873 Urban 125.20 40.0 never smoked 0 \n",
"19723 Rural 82.99 30.6 never smoked 0 \n",
"37544 Rural 166.29 25.6 formerly smoked 0 \n",
"44679 Urban 85.28 26.2 Unknown 0 \n",
"\n",
"[5110 rows x 11 columns]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
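],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"# Make scikit-learn transformers return pandas DataFrames instead of NumPy arrays\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Fixed seed so that the train/test split is reproducible\n",
"random_state = 9\n",
"\n",
"df = pd.read_csv(\"./csv/option4.csv\", index_col=\"id\")\n",
"\n",
"df"
]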
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Business goals\n",
"#### Classification\n",
"##### Goal: build a model that predicts stroke risk from health, lifestyle, and socio-demographic data.\n",
"\n",
"Applications:\n",
"- For doctors: identify high-risk patients and intervene in time.\n",
"- For healthcare systems: embed the model in patient records so that the system alerts the doctor automatically.\n",
"- For the public: raise awareness of risk factors and how to avoid them.\n",
"\n",
"#### Regression\n",
"##### Goal: build a model that predicts the glucose level from the same factors, to help track changes and assess risks.\n",
"\n",
"Applications:\n",
"- For doctors: find patients at risk of diabetes and other conditions and prescribe prevention right away.\n",
"- For healthcare systems: embed the model in patient records so that doctors see an estimated glucose level even without lab tests.\n",
"- For the public: explain how lifestyle affects glucose and recommend changes.\n",
"\n",
"### Model quality\n",
"#### Classification (stroke):\n",
"\n",
"- Features: decent. Age, hypertension, and heart disease are informative, so the model should separate risk levels reasonably well.\n",
"- Limitations: there is no data on genetics, detailed medical history, diet, or physical activity, which caps the achievable accuracy.\n",
"\n",
"#### Regression (glucose):\n",
"\n",
"- Features: decent. Age, smoking, and blood pressure are related to glucose, so the model should capture some signal.\n",
"- Limitations: data on diet, physical activity, and hormones is missing, so accuracy is limited."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification\n",
"\n",
"Split the dataset into training and test sets (80/20), stratified by the target feature `stroke`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from typing import Tuple\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input,\n",
"    stratify_colname=\"y\",\n",
"    frac_train=0.6,\n",
"    frac_val=0.15,\n",
"    frac_test=0.25,\n",
"    random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"\n",
"    Splits a Pandas dataframe into three subsets (train, val, and test)\n",
"    following fractional ratios provided by the user, where each subset is\n",
"    stratified by the values in a specific column (that is, each subset has\n",
"    the same relative frequency of the values in the column). It performs this\n",
"    splitting by running train_test_split() twice.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df_input : Pandas dataframe\n",
"        Input dataframe to be split.\n",
"    stratify_colname : str\n",
"        The name of the column that will be used for stratification. Usually\n",
"        this column would be for the label.\n",
"    frac_train : float\n",
"    frac_val : float\n",
"    frac_test : float\n",
"        The ratios with which the dataframe will be split into train, val, and\n",
"        test data. The values should be expressed as float fractions and should\n",
"        sum to 1.0.\n",
"    random_state : int, None, or RandomStateInstance\n",
"        Value to be passed to train_test_split().\n",
"\n",
"    Returns\n",
"    -------\n",
"    df_train, df_val, df_test, y_train, y_val, y_test :\n",
"        Feature dataframes for the three splits (all columns, including the\n",
"        stratification column), followed by the matching single-column target\n",
"        dataframes. If frac_val <= 0, df_val and y_val are empty dataframes.\n",
"    \"\"\"\n",
"\n",
"    # Compare with a tolerance: float fractions such as 0.7 + 0.2 + 0.1\n",
"    # do not always sum to exactly 1.0 in binary floating point.\n",
"    if abs(frac_train + frac_val + frac_test - 1.0) > 1e-9:\n",
"        raise ValueError(\n",
"            \"fractions %f, %f, %f do not add up to 1.0\"\n",
"            % (frac_train, frac_val, frac_test)\n",
"        )\n",
"\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
"\n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
"\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
"\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
"\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test"
]
},
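{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on the fraction check: in binary floating point, fractions that add up to 1 in decimal arithmetic do not always sum to exactly `1.0`. Added left to right, a 0.7/0.2/0.1 split gives a value just below 1, so a check based on exact equality would reject a valid split. The standard-library `math.isclose` compares with a relative tolerance instead. A minimal sketch (the fractions are hypothetical):\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Hypothetical split fractions that sum to 1 in decimal arithmetic\n",
"frac_train, frac_val, frac_test = 0.7, 0.2, 0.1\n",
"\n",
"# Added left to right, rounding leaves the total just below 1.0\n",
"total = frac_train + frac_val + frac_test\n",
"\n",
"exact_ok = total == 1.0                 # exact comparison fails\n",
"tolerant_ok = math.isclose(total, 1.0)  # default rel_tol is 1e-09\n",
"\n",
"print(total, exact_ok, tolerant_ok)\n",
"```"
]
},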
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>22159</th>\n",
" <td>Female</td>\n",
" <td>54.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>97.06</td>\n",
" <td>28.5</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8920</th>\n",
" <td>Female</td>\n",
" <td>51.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>76.35</td>\n",
" <td>33.5</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65507</th>\n",
" <td>Male</td>\n",
" <td>33.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>55.72</td>\n",
" <td>38.2</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43196</th>\n",
" <td>Female</td>\n",
" <td>52.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>59.54</td>\n",
" <td>42.2</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>59745</th>\n",
" <td>Female</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>76.74</td>\n",
" <td>53.9</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>66546</th>\n",
" <td>Female</td>\n",
" <td>20.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>80.08</td>\n",
" <td>25.1</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68798</th>\n",
" <td>Female</td>\n",
" <td>58.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>59.86</td>\n",
" <td>28.0</td>\n",
" <td>formerly smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61409</th>\n",
" <td>Male</td>\n",
" <td>32.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>58.24</td>\n",
" <td>NaN</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69259</th>\n",
" <td>Female</td>\n",
" <td>77.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>100.85</td>\n",
" <td>29.5</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17231</th>\n",
" <td>Female</td>\n",
" <td>24.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>90.42</td>\n",
" <td>24.3</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"22159 Female 54.0 1 0 No Private \n",
"8920 Female 51.0 0 0 Yes Self-employed \n",
"65507 Male 33.0 0 0 Yes Private \n",
"43196 Female 52.0 0 0 Yes Self-employed \n",
"59745 Female 27.0 0 0 Yes Private \n",
"... ... ... ... ... ... ... \n",
"66546 Female 20.0 0 0 No Private \n",
"68798 Female 58.0 0 0 Yes Private \n",
"61409 Male 32.0 1 0 No Govt_job \n",
"69259 Female 77.0 0 0 Yes Private \n",
"17231 Female 24.0 0 0 No Private \n",
"\n",
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
"id \n",
"22159 Urban 97.06 28.5 formerly smoked 0 \n",
"8920 Rural 76.35 33.5 formerly smoked 0 \n",
"65507 Rural 55.72 38.2 never smoked 0 \n",
"43196 Urban 59.54 42.2 Unknown 0 \n",
"59745 Urban 76.74 53.9 Unknown 0 \n",
"... ... ... ... ... ... \n",
"66546 Urban 80.08 25.1 never smoked 0 \n",
"68798 Rural 59.86 28.0 formerly smoked 1 \n",
"61409 Urban 58.24 NaN formerly smoked 0 \n",
"69259 Rural 100.85 29.5 smokes 0 \n",
"17231 Urban 90.42 24.3 never smoked 0 \n",
"\n",
"[4088 rows x 11 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>22159</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8920</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65507</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43196</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>59745</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>66546</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68798</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61409</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69259</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17231</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" stroke\n",
"id \n",
"22159 0\n",
"8920 0\n",
"65507 0\n",
"43196 0\n",
"59745 0\n",
"... ...\n",
"66546 0\n",
"68798 1\n",
"61409 0\n",
"69259 0\n",
"17231 0\n",
"\n",
"[4088 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>18072</th>\n",
" <td>Female</td>\n",
" <td>39.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>107.47</td>\n",
" <td>21.3</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>67063</th>\n",
" <td>Male</td>\n",
" <td>62.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>130.56</td>\n",
" <td>36.1</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40387</th>\n",
" <td>Female</td>\n",
" <td>17.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>77.46</td>\n",
" <td>24.0</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18032</th>\n",
" <td>Male</td>\n",
" <td>62.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>90.61</td>\n",
" <td>25.8</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5478</th>\n",
" <td>Female</td>\n",
" <td>60.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>203.04</td>\n",
" <td>NaN</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>57710</th>\n",
" <td>Female</td>\n",
" <td>50.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>112.25</td>\n",
" <td>21.6</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63043</th>\n",
" <td>Female</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>61.80</td>\n",
" <td>26.8</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63986</th>\n",
" <td>Male</td>\n",
" <td>60.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>153.48</td>\n",
" <td>37.3</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28461</th>\n",
" <td>Male</td>\n",
" <td>15.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Never_worked</td>\n",
" <td>Rural</td>\n",
" <td>79.59</td>\n",
" <td>28.4</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54975</th>\n",
" <td>Male</td>\n",
" <td>7.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>64.06</td>\n",
" <td>18.9</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1022 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"18072 Female 39.0 0 0 Yes Govt_job \n",
"67063 Male 62.0 0 0 Yes Self-employed \n",
"40387 Female 17.0 0 0 No Private \n",
"18032 Male 62.0 0 1 Yes Private \n",
"5478 Female 60.0 0 0 Yes Self-employed \n",
"... ... ... ... ... ... ... \n",
"57710 Female 50.0 0 0 Yes Private \n",
"63043 Female 27.0 0 0 No Private \n",
"63986 Male 60.0 0 0 Yes Private \n",
"28461 Male 15.0 0 0 No Never_worked \n",
"54975 Male 7.0 0 0 No Self-employed \n",
"\n",
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
"id \n",
"18072 Urban 107.47 21.3 Unknown 0 \n",
"67063 Urban 130.56 36.1 Unknown 0 \n",
"40387 Rural 77.46 24.0 Unknown 0 \n",
"18032 Rural 90.61 25.8 smokes 0 \n",
"5478 Urban 203.04 NaN smokes 0 \n",
"... ... ... ... ... ... \n",
"57710 Rural 112.25 21.6 Unknown 0 \n",
"63043 Urban 61.80 26.8 formerly smoked 0 \n",
"63986 Rural 153.48 37.3 never smoked 0 \n",
"28461 Rural 79.59 28.4 Unknown 0 \n",
"54975 Rural 64.06 18.9 Unknown 0 \n",
"\n",
"[1022 rows x 11 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>18072</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>67063</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40387</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18032</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5478</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>57710</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63043</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63986</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28461</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54975</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1022 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" stroke\n",
"id \n",
"18072 0\n",
"67063 0\n",
"40387 0\n",
"18032 0\n",
"5478 0\n",
"... ...\n",
"57710 0\n",
"63043 0\n",
"63986 0\n",
"28461 0\n",
"54975 0\n",
"\n",
"[1022 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"stroke\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
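{
"cell_type": "markdown",
"metadata": {},
"source": [
"Stratification matters here because strokes are rare. A minimal sketch on hypothetical toy data (1000 rows, 5% positives) shows that an 80/20 `train_test_split` with `stratify` keeps the positive share the same in both parts:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hypothetical toy data: 1000 rows, 50 positives (5%), like a rare target\n",
"rng = np.random.default_rng(0)\n",
"toy = pd.DataFrame({\n",
"    'feature': rng.normal(size=1000),\n",
"    'target': [1] * 50 + [0] * 950,\n",
"})\n",
"\n",
"# 80/20 split stratified by the target\n",
"train, test = train_test_split(\n",
"    toy, test_size=0.2, stratify=toy['target'], random_state=9\n",
")\n",
"\n",
"# Both parts keep the overall 5% share of positives\n",
"print(len(train), train['target'].mean())\n",
"print(len(test), test['target'].mean())\n",
"```\n",
"\n",
"Without `stratify`, the number of positives in a 200-row test set would vary from one random split to another."
]
},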
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's set a baseline for the classification task: a random predictor that, for each example, outputs a class chosen uniformly at random."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline Accuracy: 0.5\n",
"Baseline Precision: 0.05758157389635317\n",
"Baseline Recall: 0.6\n",
"Baseline F1 Score: 0.10507880910683012\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score\n",
"\n",
"# Unique classes of the target feature, taken from the training set\n",
"unique_classes = np.unique(y_train)\n",
"\n",
"# Random predictions: for each test example, draw one of the classes uniformly at random\n",
"random_predictions = np.random.choice(unique_classes, size=len(y_test))\n",
"\n",
"# Baseline metrics\n",
"baseline_accuracy = accuracy_score(y_test, random_predictions)\n",
"baseline_precision = precision_score(y_test, random_predictions)\n",
"baseline_recall = recall_score(y_test, random_predictions)\n",
"baseline_f1 = f1_score(y_test, random_predictions)\n",
"\n",
"print('Baseline Accuracy:', baseline_accuracy)\n",
"print('Baseline Precision:', baseline_precision)\n",
"print('Baseline Recall:', baseline_recall)\n",
"print('Baseline F1 Score:', baseline_f1)"
]
},
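{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same uniform-random baseline is available in scikit-learn as `DummyClassifier(strategy='uniform')`, which accepts a `random_state`, so its results are reproducible across runs (the `np.random.choice` call above is not seeded). Below is a sketch on hypothetical labels with about 5% positives; it also runs the `most_frequent` strategy to show why accuracy alone is misleading on this kind of data:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.dummy import DummyClassifier\n",
"from sklearn.metrics import accuracy_score, f1_score, recall_score\n",
"\n",
"# Hypothetical labels with a 5% positive share, roughly like stroke\n",
"y_train_toy = np.array([1] * 40 + [0] * 760)\n",
"y_test_toy = np.array([1] * 10 + [0] * 190)\n",
"# DummyClassifier ignores feature values, so placeholder features suffice\n",
"X_train_toy = np.zeros((len(y_train_toy), 1))\n",
"X_test_toy = np.zeros((len(y_test_toy), 1))\n",
"\n",
"results = {}\n",
"for strategy in ('uniform', 'most_frequent'):\n",
"    clf = DummyClassifier(strategy=strategy, random_state=9)\n",
"    clf.fit(X_train_toy, y_train_toy)\n",
"    pred = clf.predict(X_test_toy)\n",
"    results[strategy] = (\n",
"        accuracy_score(y_test_toy, pred),\n",
"        recall_score(y_test_toy, pred),\n",
"        # zero_division=0: with no positive predictions, precision is undefined\n",
"        f1_score(y_test_toy, pred, zero_division=0),\n",
"    )\n",
"    print(strategy, results[strategy])\n",
"```\n",
"\n",
"The `most_frequent` baseline never predicts a stroke: its accuracy is 0.95, while its recall and F1 are 0."
]
},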
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Model metrics\n",
"- Accuracy: the share of correct predictions among all examples. Simple, but misleading under class imbalance, because it ignores how the model handles the rare class: on this test set, a model that always predicts \"no stroke\" would already score about 0.95.\n",
"- Precision: the share of true positives among all examples predicted as positive. Useful when false alarms are costly (for example, flagging healthy patients as being at risk of stroke).\n",
"- Recall: the share of actual positive examples that the model finds. It shows how well the model \"catches\" the positive class and matters for minimizing missed strokes.\n",
"- F1 Score: the harmonic mean of precision and recall. It accounts for both, which matters for imbalanced classes.\n",
"\n",
"Together, these metrics cover different aspects of model behavior, from overall accuracy to the ability to find the rare class and balance precision against recall, which allows a well-rounded evaluation of the model."
]
},
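{
"cell_type": "markdown",
"metadata": {},
"source": [
"The definitions can be checked against the baseline printout above. The test split has 1022 rows, and the four printed values are consistent only with 50 actual stroke cases and 521 predicted positives, of which 30 are correct. Recomputing the metrics by hand from these confusion-matrix counts reproduces them:\n",
"\n",
"```python\n",
"# Confusion-matrix counts implied by the baseline output\n",
"# (50 actual stroke cases and 972 non-stroke cases in the test split)\n",
"tp, fp, fn, tn = 30, 491, 20, 481\n",
"\n",
"accuracy = (tp + tn) / (tp + fp + fn + tn)          # correct predictions among all\n",
"precision = tp / (tp + fp)                          # correct among predicted positives\n",
"recall = tp / (tp + fn)                             # found among actual positives\n",
"f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two\n",
"\n",
"print(accuracy, precision, recall, f1)\n",
"```\n",
"\n",
"Accuracy is 0.5 even though the positive class is handled poorly: precision (about 0.058) is barely above the share of stroke cases in the test set (50/1022, about 0.049)."
]
},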
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Build the preprocessing pipeline for classification"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"# Columns excluded from the features: work_type is not used, stroke is the target\n",
"columns_to_drop = [\"work_type\", \"stroke\"]\n",
"# Binary 0/1 flags that are passed through without transformation\n",
"columns_not_to_modify = [\"hypertension\", \"heart_disease\"]\n",
"\n",
"# Remaining non-object columns are numeric features\n",
"num_columns = [\n",
"    column\n",
"    for column in df.columns\n",
"    if column not in columns_to_drop\n",
"    and column not in columns_not_to_modify\n",
"    and df[column].dtype != \"object\"\n",
"]\n",
"\n",
"# Remaining object columns are categorical features\n",
"cat_columns = [\n",
"    column\n",
"    for column in df.columns\n",
"    if column not in columns_to_drop\n",
"    and column not in columns_not_to_modify\n",
"    and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"# Numeric features: fill gaps (e.g. missing bmi) with the median, then standardize\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
"    [\n",
"        (\"imputer\", num_imputer),\n",
"        (\"scaler\", num_scaler),\n",
"    ]\n",
")\n",
"\n",
"# Categorical features: fill gaps with a constant, then one-hot encode;\n",
"# drop=\"first\" removes one redundant column per feature\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
"    [\n",
"        (\"imputer\", cat_imputer),\n",
"        (\"encoder\", cat_encoder),\n",
"    ]\n",
")\n",
"\n",
"# Apply both pipelines; all other columns pass through unchanged\n",
"features_preprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"preprocessing_num\", preprocessing_num, num_columns),\n",
"        (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"# Drop the excluded columns, which arrive here via the passthrough remainder\n",
"drop_columns = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"drop_columns\", \"drop\", columns_to_drop),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
"    [\n",
"        (\"features_preprocessing\", features_preprocessing),\n",
"        (\"drop_columns\", drop_columns),\n",
"    ]\n",
")"
]
},
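{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what this preprocessing does to individual columns, here is a minimal sketch that applies the same kind of `ColumnTransformer` (median imputation plus scaling for numeric columns, one-hot encoding with `drop='first'` for categorical ones, passthrough for the rest) to a hypothetical four-row frame. It assumes scikit-learn 1.2 or newer, which `sparse_output` and pandas output require:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn import set_config\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"set_config(transform_output='pandas')\n",
"\n",
"# Hypothetical frame with the same kinds of columns as the dataset\n",
"toy = pd.DataFrame({\n",
"    'age': [20.0, 40.0, 60.0, 80.0],\n",
"    'bmi': [22.0, np.nan, 30.0, 26.0],\n",
"    'gender': ['Female', 'Male', 'Female', 'Male'],\n",
"    'hypertension': [0, 1, 0, 1],\n",
"})\n",
"\n",
"num = Pipeline([\n",
"    ('imputer', SimpleImputer(strategy='median')),\n",
"    ('scaler', StandardScaler()),\n",
"])\n",
"cat = Pipeline([\n",
"    ('imputer', SimpleImputer(strategy='constant', fill_value='unknown')),\n",
"    ('encoder', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop='first')),\n",
"])\n",
"prep = ColumnTransformer(\n",
"    transformers=[\n",
"        ('num', num, ['age', 'bmi']),\n",
"        ('cat', cat, ['gender']),\n",
"    ],\n",
"    remainder='passthrough',  # hypertension is kept as-is\n",
"    verbose_feature_names_out=False,\n",
")\n",
"\n",
"out = prep.fit_transform(toy)\n",
"print(list(out.columns))\n",
"print(out.round(3))\n",
"```\n",
"\n",
"The missing `bmi` value is replaced by the median (26.0) before scaling, so it ends up at exactly 0, and the two-valued `gender` column becomes a single `gender_Male` indicator column."
]
},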
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's check the pipeline's output:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>gender_Male</th>\n",
" <th>gender_Other</th>\n",
" <th>ever_married_Yes</th>\n",
" <th>Residence_type_Urban</th>\n",
" <th>smoking_status_formerly smoked</th>\n",
" <th>smoking_status_never smoked</th>\n",
" <th>smoking_status_smokes</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>22159</th>\n",
" <td>0.472344</td>\n",
" <td>-0.194427</td>\n",
" <td>-0.059214</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8920</th>\n",
" <td>0.339807</td>\n",
" <td>-0.653763</td>\n",
" <td>0.587887</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65507</th>\n",
" <td>-0.455418</td>\n",
" <td>-1.111325</td>\n",
" <td>1.196162</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43196</th>\n",
" <td>0.383986</td>\n",
" <td>-1.026600</td>\n",
" <td>1.713843</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>59745</th>\n",
" <td>-0.720492</td>\n",
" <td>-0.645113</td>\n",
" <td>3.228060</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>66546</th>\n",
" <td>-1.029746</td>\n",
" <td>-0.571034</td>\n",
" <td>-0.499243</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68798</th>\n",
" <td>0.649060</td>\n",
" <td>-1.019502</td>\n",
" <td>-0.123924</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61409</th>\n",
" <td>-0.499597</td>\n",
" <td>-1.055433</td>\n",
" <td>-0.098040</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69259</th>\n",
" <td>1.488464</td>\n",
" <td>-0.110367</td>\n",
" <td>0.070206</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17231</th>\n",
" <td>-0.853030</td>\n",
" <td>-0.341699</td>\n",
" <td>-0.602779</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" age avg_glucose_level bmi gender_Male gender_Other \\\n",
"id \n",
"22159 0.472344 -0.194427 -0.059214 0.0 0.0 \n",
"8920 0.339807 -0.653763 0.587887 0.0 0.0 \n",
"65507 -0.455418 -1.111325 1.196162 1.0 0.0 \n",
"43196 0.383986 -1.026600 1.713843 0.0 0.0 \n",
"59745 -0.720492 -0.645113 3.228060 0.0 0.0 \n",
"... ... ... ... ... ... \n",
"66546 -1.029746 -0.571034 -0.499243 0.0 0.0 \n",
"68798 0.649060 -1.019502 -0.123924 0.0 0.0 \n",
"61409 -0.499597 -1.055433 -0.098040 1.0 0.0 \n",
"69259 1.488464 -0.110367 0.070206 0.0 0.0 \n",
"17231 -0.853030 -0.341699 -0.602779 0.0 0.0 \n",
"\n",
" ever_married_Yes Residence_type_Urban smoking_status_formerly smoked \\\n",
"id \n",
"22159 0.0 1.0 1.0 \n",
"8920 1.0 0.0 1.0 \n",
"65507 1.0 0.0 0.0 \n",
"43196 1.0 1.0 0.0 \n",
"59745 1.0 1.0 0.0 \n",
"... ... ... ... \n",
"66546 0.0 1.0 0.0 \n",
"68798 1.0 0.0 1.0 \n",
"61409 0.0 1.0 1.0 \n",
"69259 1.0 0.0 0.0 \n",
"17231 0.0 1.0 0.0 \n",
"\n",
" smoking_status_never smoked smoking_status_smokes hypertension \\\n",
"id \n",
"22159 0.0 0.0 1 \n",
"8920 0.0 0.0 0 \n",
"65507 1.0 0.0 0 \n",
"43196 0.0 0.0 0 \n",
"59745 0.0 0.0 0 \n",
"... ... ... ... \n",
"66546 1.0 0.0 0 \n",
"68798 0.0 0.0 0 \n",
"61409 0.0 0.0 1 \n",
"69259 0.0 1.0 0 \n",
"17231 1.0 0.0 0 \n",
"\n",
" heart_disease \n",
"id \n",
"22159 0 \n",
"8920 0 \n",
"65507 0 \n",
"43196 0 \n",
"59745 0 \n",
"... ... \n",
"66546 0 \n",
"68798 0 \n",
"61409 0 \n",
"69259 0 \n",
"17231 0 \n",
"\n",
"[4088 rows x 12 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's tune the hyperparameters of each selected model with a grid search and collect the resulting set of models.\n",
"\n",
"knn -- k-nearest neighbors\n",
"\n",
"random_forest -- random forest (an ensemble of decision trees)\n",
"\n",
"mlp -- multilayer perceptron (a neural network)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\code\\mai\\labs\\AIM-PIbd-31-Bakalskaya-E-D\\lab_4\\venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters for knn: {'n_neighbors': 1, 'weights': 'uniform'}\n",
"Best parameters for random_forest: {'class_weight': 'balanced_subsample', 'criterion': 'entropy', 'max_depth': 7, 'max_features': 'sqrt', 'n_estimators': 50, 'random_state': 9}\n",
"Best parameters for mlp: {'alpha': np.float64(0.1), 'early_stopping': True, 'hidden_layer_sizes': np.int64(14), 'max_iter': 1000, 'random_state': 9, 'solver': 'adam'}\n"
]
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn import neighbors, ensemble, neural_network\n",
"\n",
"# Candidate hyperparameter values for each model\n",
"param_grids = {\n",
"    \"knn\": {\n",
"        \"n_neighbors\": list(range(1, 31)),  # 1..30\n",
"        \"weights\": ['uniform', 'distance']\n",
"    },\n",
"    \"random_forest\": {\n",
"        \"n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"        \"max_features\": [\"sqrt\", \"log2\", 2],\n",
"        \"max_depth\": [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
"        \"criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"        \"random_state\": [random_state],\n",
"        \"class_weight\": [\"balanced\", \"balanced_subsample\"]\n",
"    },\n",
"    \"mlp\": {\n",
"        \"solver\": ['adam'],\n",
"        \"max_iter\": list(range(1000, 2001, 100)),  # 1000..2000 in steps of 100\n",
"        \"alpha\": 10.0 ** -np.arange(1, 10),\n",
"        \"hidden_layer_sizes\": np.arange(10, 15),\n",
"        \"early_stopping\": [True, False],\n",
"        \"random_state\": [random_state]\n",
"    }\n",
"}\n",
"\n",
"# Create the model instances\n",
"models = {\n",
"    \"knn\": neighbors.KNeighborsClassifier(),\n",
"    \"random_forest\": ensemble.RandomForestClassifier(),\n",
"    \"mlp\": neural_network.MLPClassifier()\n",
"}\n",
"\n",
"# Stores each model configured with its best parameters\n",
"class_models = {}\n",
"\n",
"# Run a grid search for each model\n",
"for model_name, model in models.items():\n",
"    # GridSearchCV for the current model: F1 scoring, default 5-fold CV\n",
"    gs_optimizer = GridSearchCV(estimator=model, param_grid=param_grids[model_name], scoring=\"f1\", n_jobs=-1)\n",
"\n",
"    # Fit the search on the preprocessed training data\n",
"    gs_optimizer.fit(preprocessed_df, y_train.values.ravel())\n",
"\n",
"    # Retrieve the best parameters\n",
"    best_params = gs_optimizer.best_params_\n",
"    print(f\"Best parameters for {model_name}: {best_params}\")\n",
"\n",
"    class_models[model_name] = {\n",
"        \"model\": model.set_params(**best_params)  # configure the model with the best parameters\n",
"    }"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Whoa, this thing took 12 minutes to run..."
]
},
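{
"cell_type": "markdown",
"metadata": {},
"source": [
"Counting the hyperparameter combinations in each grid gives a rough idea of why the search is slow. By hand, there are 30 × 2 = 60 combinations for knn, 10 × 3 × 9 × 3 × 2 = 1620 for random_forest and 11 × 9 × 5 × 2 = 990 for mlp. The sketch below is a minimal check of these counts. It reuses `param_grids` from the grid-search cell and assumes GridSearchCV's default 5-fold cross-validation (the cell does not pass `cv`), so each combination is fitted five times."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import ParameterGrid\n",
"\n",
"# Assumption: the grid search above ran with GridSearchCV's default cv=5\n",
"n_folds = 5\n",
"\n",
"total_fits = 0\n",
"for model_name, grid in param_grids.items():\n",
"    # ParameterGrid expands a dict of value lists into every combination\n",
"    n_combinations = len(ParameterGrid(grid))\n",
"    total_fits += n_combinations * n_folds\n",
"    print(f'{model_name}: {n_combinations} combinations x {n_folds} folds')\n",
"\n",
"print(f'Total cross-validation fits: {total_fits}')"
]
},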
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, now let's train each model with its tuned hyperparameters and see how well it performs."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = accuracy_score(\n",
" y_test, y_test_predict\n",
" ) \n",
" class_models[model_name][\"F1_train\"] = f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"Confusion_matrix\"] = confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Confusion matrices:"
]
},
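{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, scikit-learn's `confusion_matrix` puts the true classes in rows and the predicted classes in columns. For a binary task, `.ravel()` therefore returns `tn, fp, fn, tp` in that order. The sketch below is a minimal illustration on made-up labels, not the stroke data. The `demo_` names are hypothetical and were chosen so they do not overwrite the notebook's variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# Made-up labels for illustration only\n",
"demo_true = np.array([0, 0, 0, 1, 1])\n",
"demo_pred = np.array([0, 1, 0, 1, 0])\n",
"\n",
"# Rows = true class (0, 1), columns = predicted class (0, 1)\n",
"print(confusion_matrix(demo_true, demo_pred))\n",
"\n",
"demo_tn, demo_fp, demo_fn, demo_tp = confusion_matrix(demo_true, demo_pred).ravel()\n",
"print(f'tn={demo_tn}, fp={demo_fp}, fn={demo_fn}, tp={demo_tp}')  # tn=2, fp=1, fn=1, tp=1"
]
},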
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'Confusion_matrix'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[19], line 19\u001b[0m\n\u001b[0;32m 16\u001b[0m fig, ax \u001b[38;5;241m=\u001b[39m plt\u001b[38;5;241m.\u001b[39msubplots(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m2\u001b[39m, figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m12\u001b[39m, \u001b[38;5;241m10\u001b[39m))\n\u001b[0;32m 18\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m index, (key, model_info) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(class_models\u001b[38;5;241m.\u001b[39mitems()):\n\u001b[1;32m---> 19\u001b[0m c_matrix \u001b[38;5;241m=\u001b[39m \u001b[43mmodel_info\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mConfusion_matrix\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[0;32m 21\u001b[0m \u001b[38;5;66;03m# Генерация кастомных цветов\u001b[39;00m\n\u001b[0;32m 22\u001b[0m disp \u001b[38;5;241m=\u001b[39m ConfusionMatrixDisplay(\n\u001b[0;32m 23\u001b[0m confusion_matrix\u001b[38;5;241m=\u001b[39mc_matrix, display_labels\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNot stroke\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mStroke\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m 24\u001b[0m )\n",
"\u001b[1;31mKeyError\u001b[0m: 'Confusion_matrix'"
]
},
{
"data": {
"text/plain": [
"<Figure size 1200x1000 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"from matplotlib.colors import ListedColormap\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Color scheme: light blue for correct predictions (diagonal), pale red for errors (off-diagonal).\n",
"# A colormap maps cell values, not positions, so the cells are recolored with a 0/1 mask below.\n",
"correct_vs_error_cmap = ListedColormap([\"#f08080\", \"#add8e6\"])  # 0 -> pale red, 1 -> light blue\n",
"\n",
"fig, ax = plt.subplots(2, 2, figsize=(12, 10))\n",
"\n",
"for index, (key, model_info) in enumerate(class_models.items()):\n",
"    c_matrix = model_info[\"Confusion_matrix\"]\n",
"\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Not stroke\", \"Stroke\"]\n",
"    )\n",
"    disp.plot(ax=ax.flat[index], colorbar=False)\n",
"\n",
"    # Replace the count-based image with the mask: 1 on the diagonal, 0 elsewhere\n",
"    disp.im_.set_data(np.eye(c_matrix.shape[0]))\n",
"    disp.im_.set_cmap(correct_vs_error_cmap)\n",
"    disp.im_.set_clim(0, 1)\n",
"    # Black count labels stay readable on both light colors\n",
"    for text in disp.text_.ravel():\n",
"        text.set_color(\"black\")\n",
"\n",
"    disp.ax_.set_title(key)\n",
"\n",
"if len(class_models) < len(ax.flat):\n",
"    for i in range(len(class_models), len(ax.flat)):\n",
"        fig.delaxes(ax.flat[i])\n",
"\n",
"plt.subplots_adjust(top=0.9, bottom=0.1, hspace=0.4, wspace=0.3)\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Precision, Recall, Accuracy, F1:"
]
},
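{
"cell_type": "markdown",
"metadata": {},
"source": [
"The four metrics in the table can be expressed through the confusion-matrix counts:\n",
"\n",
"- Precision = TP / (TP + FP)\n",
"- Recall = TP / (TP + FN)\n",
"- Accuracy = (TP + TN) / (TP + TN + FP + FN)\n",
"- F1 = 2 · Precision · Recall / (Precision + Recall)\n",
"\n",
"The sketch below is a small check, on made-up labels rather than the stroke data, that these formulas agree with scikit-learn's metric functions. The `demo_` names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n",
"\n",
"# Made-up labels for illustration only\n",
"demo_true = np.array([0, 0, 0, 1, 1])\n",
"demo_pred = np.array([0, 1, 0, 1, 0])\n",
"\n",
"# Confusion-matrix counts for the positive class (1)\n",
"demo_tp = np.sum((demo_true == 1) & (demo_pred == 1))\n",
"demo_tn = np.sum((demo_true == 0) & (demo_pred == 0))\n",
"demo_fp = np.sum((demo_true == 0) & (demo_pred == 1))\n",
"demo_fn = np.sum((demo_true == 1) & (demo_pred == 0))\n",
"\n",
"demo_precision = demo_tp / (demo_tp + demo_fp)\n",
"demo_recall = demo_tp / (demo_tp + demo_fn)\n",
"demo_accuracy = (demo_tp + demo_tn) / len(demo_true)\n",
"demo_f1 = 2 * demo_precision * demo_recall / (demo_precision + demo_recall)\n",
"\n",
"# The hand-computed values agree with scikit-learn\n",
"assert np.isclose(demo_precision, precision_score(demo_true, demo_pred))\n",
"assert np.isclose(demo_recall, recall_score(demo_true, demo_pred))\n",
"assert np.isclose(demo_accuracy, accuracy_score(demo_true, demo_pred))\n",
"assert np.isclose(demo_f1, f1_score(demo_true, demo_pred))\n",
"print(demo_precision, demo_recall, demo_accuracy, demo_f1)  # 0.5 0.5 0.6 0.5"
]
},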
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_0e668_row0_col0 {\n",
" background-color: #1f988b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0e668_row0_col1, #T_0e668_row1_col0, #T_0e668_row1_col2, #T_0e668_row2_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_0e668_row0_col2, #T_0e668_row0_col3, #T_0e668_row1_col1, #T_0e668_row2_col0 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0e668_row0_col4 {\n",
" background-color: #b7318a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0e668_row0_col5, #T_0e668_row1_col4, #T_0e668_row1_col6, #T_0e668_row2_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0e668_row0_col6, #T_0e668_row0_col7, #T_0e668_row2_col4, #T_0e668_row2_col5 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0e668_row1_col3, #T_0e668_row2_col1 {\n",
" background-color: #1f968b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0e668_row1_col5 {\n",
" background-color: #be3885;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0e668_row1_col7 {\n",
" background-color: #9c179e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0e668_row2_col2 {\n",
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
"#T_0e668_row2_col6 {\n",
" background-color: #8808a6;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_0e668\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_0e668_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_0e668_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_0e668_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_0e668_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_0e668_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_0e668_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_0e668_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_0e668_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_0e668_level0_row0\" class=\"row_heading level0 row0\" >mlp</th>\n",
" <td id=\"T_0e668_row0_col0\" class=\"data row0 col0\" >0.400000</td>\n",
" <td id=\"T_0e668_row0_col1\" class=\"data row0 col1\" >0.200000</td>\n",
" <td id=\"T_0e668_row0_col2\" class=\"data row0 col2\" >0.020101</td>\n",
" <td id=\"T_0e668_row0_col3\" class=\"data row0 col3\" >0.020000</td>\n",
" <td id=\"T_0e668_row0_col4\" class=\"data row0 col4\" >0.950832</td>\n",
" <td id=\"T_0e668_row0_col5\" class=\"data row0 col5\" >0.948141</td>\n",
" <td id=\"T_0e668_row0_col6\" class=\"data row0 col6\" >0.038278</td>\n",
" <td id=\"T_0e668_row0_col7\" class=\"data row0 col7\" >0.036364</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0e668_level0_row1\" class=\"row_heading level0 row1\" >knn</th>\n",
" <td id=\"T_0e668_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_0e668_row1_col1\" class=\"data row1 col1\" >0.117647</td>\n",
" <td id=\"T_0e668_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_0e668_row1_col3\" class=\"data row1 col3\" >0.120000</td>\n",
" <td id=\"T_0e668_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_0e668_row1_col5\" class=\"data row1 col5\" >0.912916</td>\n",
" <td id=\"T_0e668_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_0e668_row1_col7\" class=\"data row1 col7\" >0.118812</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0e668_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_0e668_row2_col0\" class=\"data row2 col0\" >0.228869</td>\n",
" <td id=\"T_0e668_row2_col1\" class=\"data row2 col1\" >0.135135</td>\n",
" <td id=\"T_0e668_row2_col2\" class=\"data row2 col2\" >0.884422</td>\n",
" <td id=\"T_0e668_row2_col3\" class=\"data row2 col3\" >0.500000</td>\n",
" <td id=\"T_0e668_row2_col4\" class=\"data row2 col4\" >0.849315</td>\n",
" <td id=\"T_0e668_row2_col5\" class=\"data row2 col5\" >0.818982</td>\n",
" <td id=\"T_0e668_row2_col6\" class=\"data row2 col6\" >0.363636</td>\n",
" <td id=\"T_0e668_row2_col7\" class=\"data row2 col7\" >0.212766</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1bdc4916000>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Model analysis:\n",
"##### MLP (multilayer perceptron):\n",
"- Accuracy: 95% (training and test).\n",
"- Precision and Recall: extremely low (0.40 and 0.02 on training, 0.20 and 0.02 on test).\n",
"- F1: practically zero (0.038 and 0.036).\n",
"- Conclusion: the model predicts the majority (no-stroke) class well but almost never detects positive examples.\n",
"\n",
"##### KNN (k-nearest neighbors):\n",
"- Training: perfect metrics (1.0).\n",
"- Test: a sharp drop (Precision 0.118, Recall 0.12, Accuracy 91%).\n",
"- Conclusion: overfitting. With the selected n_neighbors=1, the model essentially memorizes the training set, so it performs perfectly there but generalizes poorly to new data.\n",
"\n",
"##### Random Forest:\n",
"- Accuracy: 85% (training) and 82% (test).\n",
"- Precision and Recall on test: 0.135 and 0.50. Recall is moderate, but precision is low.\n",
"- F1: better than the other models (0.213 on test).\n",
"- Conclusion: the balance of metrics is better than for the other models, but precision on positive examples still leaves much to be desired.\n",
"\n",
"##### Comparison with the baseline:\n",
"- Baseline (simple model): Accuracy 52%, Precision 0.058, Recall 0.58, F1 0.106.\n",
"- Accuracy: all models clearly outperform the baseline.\n",
"- Recall: none of the models reaches the baseline's 0.58. Random Forest comes closest (0.50), while KNN (0.12) and MLP (0.02) fall far behind.\n",
"- F1: Random Forest leads again (0.213 vs 0.106 for the baseline), but it is still far from the desired level.\n",
"\n",
"##### Summary:\n",
"- MLP: strongly biased toward the majority class; it ignores positive examples.\n",
"- KNN: high variance; it is severely overfitted.\n",
"- Random Forest: the most balanced option, but its precision needs improvement.\n",
"\n",
"Bottom line: Random Forest is the best of the proposed models, but it needs further work."
]
},
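{
"cell_type": "markdown",
"metadata": {},
"source": [
"The MLP result (about 95% accuracy with F1 close to zero) is a typical symptom of class imbalance: a model that always predicts class 0 (no stroke) already achieves high accuracy. The sketch below is a minimal illustration using scikit-learn's `DummyClassifier` on made-up labels. The 5% positive rate is an assumed ratio chosen for illustration, not measured from the dataset, and the `demo_` names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.dummy import DummyClassifier\n",
"from sklearn.metrics import accuracy_score, f1_score\n",
"\n",
"# Made-up imbalanced labels: 5 positives out of 100 (assumed ratio)\n",
"demo_y = np.array([1] * 5 + [0] * 95)\n",
"demo_X = np.zeros((len(demo_y), 1))  # features are irrelevant for this baseline\n",
"\n",
"# Always predicts the most frequent class (0)\n",
"demo_dummy = DummyClassifier(strategy='most_frequent').fit(demo_X, demo_y)\n",
"demo_pred = demo_dummy.predict(demo_X)\n",
"\n",
"print(accuracy_score(demo_y, demo_pred))  # 0.95\n",
"print(f1_score(demo_y, demo_pred, zero_division=0))  # 0.0"
]
},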
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression\n",
"\n",
"Let's split the dataset into training and test sets (80/20). The target feature is avg_glucose_level."
]
},
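{
"cell_type": "markdown",
"metadata": {},
"source": [
"The row counts shown below (4088 training rows and 1022 test rows, 5110 in total) are what an 80/20 split should produce. Assuming the split uses scikit-learn's `train_test_split` with `test_size=0.2`, the test set receives ceil(0.2 × 5110) = 1022 rows and the training set the remaining 4088. The sketch below is a quick check on a hypothetical frame of the same length. The `demo_` names are illustrative and do not touch the notebook's `X_train`/`X_test`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hypothetical frame with as many rows as the stroke dataset (4088 + 1022 = 5110)\n",
"demo_df = pd.DataFrame({'age': np.arange(5110), 'avg_glucose_level': np.linspace(55.0, 272.0, 5110)})\n",
"\n",
"demo_X_train, demo_X_test, demo_y_train, demo_y_test = train_test_split(\n",
"    demo_df.drop(columns='avg_glucose_level'),\n",
"    demo_df['avg_glucose_level'],\n",
"    test_size=0.2,\n",
"    random_state=9,  # arbitrary seed for the demo\n",
")\n",
"print(len(demo_X_train), len(demo_X_test))  # 4088 1022"
]
},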
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>13276</th>\n",
" <td>Female</td>\n",
" <td>38.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>22.6</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21346</th>\n",
" <td>Female</td>\n",
" <td>12.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Rural</td>\n",
" <td>17.8</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>59178</th>\n",
" <td>Female</td>\n",
" <td>7.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Urban</td>\n",
" <td>22.3</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1679</th>\n",
" <td>Male</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>NaN</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1534</th>\n",
" <td>Female</td>\n",
" <td>61.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>26.1</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30463</th>\n",
" <td>Male</td>\n",
" <td>29.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>29.4</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>41935</th>\n",
" <td>Male</td>\n",
" <td>34.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>33.9</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68483</th>\n",
" <td>Female</td>\n",
" <td>60.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>41.2</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38617</th>\n",
" <td>Male</td>\n",
" <td>28.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>29.9</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46527</th>\n",
" <td>Male</td>\n",
" <td>53.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Rural</td>\n",
" <td>41.9</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"13276 Female 38.0 0 0 Yes Private \n",
"21346 Female 12.0 0 0 No children \n",
"59178 Female 7.0 0 0 No children \n",
"1679 Male 35.0 0 0 Yes Private \n",
"1534 Female 61.0 0 0 Yes Private \n",
"... ... ... ... ... ... ... \n",
"30463 Male 29.0 0 0 No Private \n",
"41935 Male 34.0 0 0 No Private \n",
"68483 Female 60.0 0 0 Yes Private \n",
"38617 Male 28.0 0 0 Yes Self-employed \n",
"46527 Male 53.0 1 1 Yes Govt_job \n",
"\n",
" Residence_type bmi smoking_status stroke \n",
"id \n",
"13276 Urban 22.6 Unknown 0 \n",
"21346 Rural 17.8 Unknown 0 \n",
"59178 Urban 22.3 Unknown 0 \n",
"1679 Rural NaN formerly smoked 0 \n",
"1534 Rural 26.1 smokes 0 \n",
"... ... ... ... ... \n",
"30463 Urban 29.4 formerly smoked 0 \n",
"41935 Rural 33.9 never smoked 0 \n",
"68483 Urban 41.2 formerly smoked 0 \n",
"38617 Urban 29.9 never smoked 0 \n",
"46527 Rural 41.9 never smoked 0 \n",
"\n",
"[4088 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"id\n",
"13276 71.06\n",
"21346 70.13\n",
"59178 86.75\n",
"1679 77.48\n",
"1534 99.35\n",
" ... \n",
"30463 82.93\n",
"41935 125.29\n",
"68483 65.38\n",
"38617 73.98\n",
"46527 109.51\n",
"Name: avg_glucose_level, Length: 4088, dtype: float64"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>8385</th>\n",
" <td>Male</td>\n",
" <td>37.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>35.9</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>937</th>\n",
" <td>Male</td>\n",
" <td>7.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Urban</td>\n",
" <td>NaN</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3494</th>\n",
" <td>Female</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>26.7</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23850</th>\n",
" <td>Male</td>\n",
" <td>66.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>33.1</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31156</th>\n",
" <td>Female</td>\n",
" <td>49.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>29.8</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>71010</th>\n",
" <td>Female</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>22.8</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>39518</th>\n",
" <td>Female</td>\n",
" <td>20.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>20.7</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7780</th>\n",
" <td>Male</td>\n",
" <td>51.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>30.7</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>56137</th>\n",
" <td>Female</td>\n",
" <td>62.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>36.3</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33175</th>\n",
" <td>Female</td>\n",
" <td>57.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>28.5</td>\n",
" <td>Unknown</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1022 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"8385 Male 37.0 0 0 Yes Private \n",
"937 Male 7.0 0 0 No children \n",
"3494 Female 80.0 0 0 Yes Private \n",
"23850 Male 66.0 0 0 Yes Private \n",
"31156 Female 49.0 0 0 Yes Private \n",
"... ... ... ... ... ... ... \n",
"71010 Female 80.0 0 0 No Self-employed \n",
"39518 Female 20.0 0 0 No Private \n",
"7780 Male 51.0 0 0 Yes Self-employed \n",
"56137 Female 62.0 0 0 Yes Private \n",
"33175 Female 57.0 0 0 Yes Govt_job \n",
"\n",
" Residence_type bmi smoking_status stroke \n",
"id \n",
"8385 Urban 35.9 Unknown 0 \n",
"937 Urban NaN Unknown 0 \n",
"3494 Rural 26.7 Unknown 0 \n",
"23850 Urban 33.1 never smoked 0 \n",
"31156 Urban 29.8 never smoked 0 \n",
"... ... ... ... ... \n",
"71010 Urban 22.8 never smoked 0 \n",
"39518 Rural 20.7 never smoked 0 \n",
"7780 Urban 30.7 never smoked 0 \n",
"56137 Urban 36.3 Unknown 0 \n",
"33175 Urban 28.5 Unknown 1 \n",
"\n",
"[1022 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"id\n",
"8385 90.78\n",
"937 87.94\n",
"3494 102.90\n",
"23850 103.01\n",
"31156 105.99\n",
" ... \n",
"71010 57.57\n",
"39518 78.94\n",
"7780 75.73\n",
"56137 88.32\n",
"33175 110.52\n",
"Name: avg_glucose_level, Length: 1022, dtype: float64"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"features = ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'bmi', 'smoking_status', 'stroke']\n",
"target = 'avg_glucose_level'\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=random_state)\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us choose a baseline for the regression task. We use the zero rule (ZeroR): for every test sample, the prediction is the mean of the target on the training set (`y_train`)."
]
},
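{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same zero-rule baseline can be expressed with scikit-learn's `DummyRegressor(strategy='mean')`, which ignores the features and always predicts the training mean. A minimal sketch on made-up toy values (not the stroke data):\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.dummy import DummyRegressor\n",
"\n",
"# Toy data (illustrative only): the dummy model ignores the features\n",
"X_toy_train = np.zeros((4, 1))\n",
"y_toy_train = np.array([80.0, 100.0, 120.0, 140.0])\n",
"X_toy_test = np.zeros((3, 1))\n",
"\n",
"# strategy='mean' stores the mean of y_toy_train and predicts it for every sample\n",
"dummy = DummyRegressor(strategy='mean')\n",
"dummy.fit(X_toy_train, y_toy_train)\n",
"dummy_preds = dummy.predict(X_toy_test)\n",
"print(dummy_preds)  # every prediction equals the training mean, 110.0\n",
"```\n",
"\n",
"Using it instead of a hand-built list keeps the baseline behind the same `fit`/`predict` interface as the other models."
]
},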
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline RMSE: 44.12711275645952\n",
"Baseline RMAE: 5.662154850745081\n",
"Baseline R2: -0.0010729515309222393\n"
]
}
],
"source": [
"import math\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
"\n",
"# Baseline prediction: the mean of y_train for every test sample\n",
"baseline_predictions = [y_train.mean()] * len(y_test)\n",
"\n",
"# Quality metrics of the baseline (RMAE is the square root of MAE)\n",
"baseline_rmse = math.sqrt(\n",
" mean_squared_error(y_test, baseline_predictions)\n",
" )\n",
"baseline_rmae = math.sqrt(\n",
" mean_absolute_error(y_test, baseline_predictions)\n",
" )\n",
"baseline_r2 = r2_score(y_test, baseline_predictions)\n",
"\n",
"print('Baseline RMSE:', baseline_rmse)\n",
"print('Baseline RMAE:', baseline_rmae)\n",
"print('Baseline R2:', baseline_r2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Metrics:\n",
"##### RMSE: root mean squared error, the square root of MSE. \n",
"Convenient because it is expressed in the same units as the target. It penalizes large deviations more heavily than small ones. A good choice when interpretability matters.\n",
"##### RMAE: the square root of MAE (mean absolute error). \n",
"MAE is less sensitive to outliers than MSE, which helps with data that has rare large deviations. Because of the square root, RMAE is not in the units of the target, so it is only meaningful when compared with other RMAE values.\n",
"##### R²: coefficient of determination, the share of the target's variance explained by the model.\n",
"The closer to 1, the better the model describes the data; 0 corresponds to always predicting the mean, and negative values mean the model is worse than that. R² is used to compare models evaluated on the same data.\n",
"Together, these metrics show how much more accurate a model is than simple averaging."
]
},
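{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three formulas can be checked by hand. The sketch below computes RMSE, RMAE (square root of MAE, as in the baseline cell) and `R² = 1 - SS_res / SS_tot` with NumPy on toy values (illustrative only) and compares them with scikit-learn:\n",
"\n",
"```python\n",
"import math\n",
"import numpy as np\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
"\n",
"# Toy true values and predictions (illustrative only)\n",
"y_true = np.array([90.0, 100.0, 110.0, 140.0])\n",
"y_pred = np.array([95.0, 100.0, 105.0, 120.0])\n",
"\n",
"errors = y_true - y_pred\n",
"rmse = math.sqrt(np.mean(errors ** 2))          # square root of the mean squared error\n",
"rmae = math.sqrt(np.mean(np.abs(errors)))       # square root of the mean absolute error\n",
"ss_res = np.sum(errors ** 2)                    # residual sum of squares\n",
"ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares around the mean\n",
"r2 = 1 - ss_res / ss_tot\n",
"\n",
"# The hand-written formulas agree with scikit-learn\n",
"assert math.isclose(rmse, math.sqrt(mean_squared_error(y_true, y_pred)))\n",
"assert math.isclose(rmae, math.sqrt(mean_absolute_error(y_true, y_pred)))\n",
"assert math.isclose(r2, r2_score(y_true, y_pred))\n",
"print(rmse, rmae, r2)\n",
"```\n",
"\n",
"The single large error (20) pushes RMSE to about 10.6, while MAE stays at 7.5, which illustrates the outlier sensitivity described above."
]
},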
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building the regression pipeline"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"columns_to_drop = []\n",
"columns_not_to_modify = [\"hypertension\", \"heart_disease\", \"stroke\", \"avg_glucose_level\"]\n",
"\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop\n",
" and column not in columns_not_to_modify\n",
" and df[column].dtype != \"object\"\n",
"]\n",
"\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop\n",
" and column not in columns_not_to_modify\n",
" and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end_reg = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
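{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `features_preprocessing` does to each kind of column, here is a minimal sketch of the same construction on a hypothetical three-row frame (`toy`, `age`, `color` and `flag` are made-up names). It assumes scikit-learn 1.2 or newer, which `sparse_output` requires:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"# Hypothetical frame: a numeric column with a gap, a categorical column, a binary flag\n",
"toy = pd.DataFrame({\n",
"    'age': [1.0, 2.0, np.nan],\n",
"    'color': ['a', 'b', 'a'],\n",
"    'flag': [0, 1, 0],\n",
"})\n",
"\n",
"num_pipe = Pipeline([('imputer', SimpleImputer(strategy='median')), ('scaler', StandardScaler())])\n",
"cat_pipe = Pipeline([\n",
"    ('imputer', SimpleImputer(strategy='constant', fill_value='unknown')),\n",
"    ('encoder', OneHotEncoder(sparse_output=False, drop='first')),\n",
"])\n",
"\n",
"toy_transformer = ColumnTransformer(\n",
"    transformers=[('num', num_pipe, ['age']), ('cat', cat_pipe, ['color'])],\n",
"    remainder='passthrough',          # 'flag' is passed through unchanged\n",
"    verbose_feature_names_out=False,  # keep the original column names\n",
")\n",
"\n",
"toy_out = toy_transformer.fit_transform(toy)\n",
"print(list(toy_transformer.get_feature_names_out()))  # ['age', 'color_b', 'flag']\n",
"print(toy_out)\n",
"```\n",
"\n",
"The missing age is filled with the median (1.5) and therefore scales to exactly 0, `drop='first'` removes the `color_a` column, and the untouched binary column is appended at the end, just like `hypertension`, `heart_disease` and `stroke` in the pipeline output below."
]
},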
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let us check that the pipeline works:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>bmi</th>\n",
" <th>gender_Male</th>\n",
" <th>gender_Other</th>\n",
" <th>ever_married_Yes</th>\n",
" <th>work_type_Never_worked</th>\n",
" <th>work_type_Private</th>\n",
" <th>work_type_Self-employed</th>\n",
" <th>work_type_children</th>\n",
" <th>Residence_type_Urban</th>\n",
" <th>smoking_status_formerly smoked</th>\n",
" <th>smoking_status_never smoked</th>\n",
" <th>smoking_status_smokes</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>13276</th>\n",
" <td>-0.236211</td>\n",
" <td>-0.826056</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21346</th>\n",
" <td>-1.386874</td>\n",
" <td>-1.455413</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>59178</th>\n",
" <td>-1.608155</td>\n",
" <td>-0.865391</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1679</th>\n",
" <td>-0.368980</td>\n",
" <td>-0.104918</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1534</th>\n",
" <td>0.781682</td>\n",
" <td>-0.367150</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30463</th>\n",
" <td>-0.634518</td>\n",
" <td>0.065532</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>41935</th>\n",
" <td>-0.413236</td>\n",
" <td>0.655554</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68483</th>\n",
" <td>0.737426</td>\n",
" <td>1.612701</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38617</th>\n",
" <td>-0.678774</td>\n",
" <td>0.131090</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46527</th>\n",
" <td>0.427632</td>\n",
" <td>1.704482</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 16 columns</p>\n",
"</div>"
],
"text/plain": [
" age bmi gender_Male gender_Other ever_married_Yes \\\n",
"id \n",
"13276 -0.236211 -0.826056 0.0 0.0 1.0 \n",
"21346 -1.386874 -1.455413 0.0 0.0 0.0 \n",
"59178 -1.608155 -0.865391 0.0 0.0 0.0 \n",
"1679 -0.368980 -0.104918 1.0 0.0 1.0 \n",
"1534 0.781682 -0.367150 0.0 0.0 1.0 \n",
"... ... ... ... ... ... \n",
"30463 -0.634518 0.065532 1.0 0.0 0.0 \n",
"41935 -0.413236 0.655554 1.0 0.0 0.0 \n",
"68483 0.737426 1.612701 0.0 0.0 1.0 \n",
"38617 -0.678774 0.131090 1.0 0.0 1.0 \n",
"46527 0.427632 1.704482 1.0 0.0 1.0 \n",
"\n",
" work_type_Never_worked work_type_Private work_type_Self-employed \\\n",
"id \n",
"13276 0.0 1.0 0.0 \n",
"21346 0.0 0.0 0.0 \n",
"59178 0.0 0.0 0.0 \n",
"1679 0.0 1.0 0.0 \n",
"1534 0.0 1.0 0.0 \n",
"... ... ... ... \n",
"30463 0.0 1.0 0.0 \n",
"41935 0.0 1.0 0.0 \n",
"68483 0.0 1.0 0.0 \n",
"38617 0.0 0.0 1.0 \n",
"46527 0.0 0.0 0.0 \n",
"\n",
" work_type_children Residence_type_Urban \\\n",
"id \n",
"13276 0.0 1.0 \n",
"21346 1.0 0.0 \n",
"59178 1.0 1.0 \n",
"1679 0.0 0.0 \n",
"1534 0.0 0.0 \n",
"... ... ... \n",
"30463 0.0 1.0 \n",
"41935 0.0 0.0 \n",
"68483 0.0 1.0 \n",
"38617 0.0 1.0 \n",
"46527 0.0 0.0 \n",
"\n",
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
"id \n",
"13276 0.0 0.0 \n",
"21346 0.0 0.0 \n",
"59178 0.0 0.0 \n",
"1679 1.0 0.0 \n",
"1534 0.0 0.0 \n",
"... ... ... \n",
"30463 1.0 0.0 \n",
"41935 0.0 1.0 \n",
"68483 1.0 0.0 \n",
"38617 0.0 1.0 \n",
"46527 0.0 1.0 \n",
"\n",
" smoking_status_smokes hypertension heart_disease stroke \n",
"id \n",
"13276 0.0 0 0 0 \n",
"21346 0.0 0 0 0 \n",
"59178 0.0 0 0 0 \n",
"1679 0.0 0 0 0 \n",
"1534 1.0 0 0 0 \n",
"... ... ... ... ... \n",
"30463 0.0 0 0 0 \n",
"41935 0.0 0 0 0 \n",
"68483 0.0 0 0 0 \n",
"38617 0.0 0 0 0 \n",
"46527 0.0 1 1 0 \n",
"\n",
"[4088 rows x 16 columns]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end_reg.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end_reg.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us tune the hyperparameters of each selected model with grid search (`GridSearchCV`, 5-fold cross-validation by default, scored by negative MSE) and collect the tuned models.\n",
"\n",
"knn -- k-nearest neighbors\n",
"\n",
"random_forest -- random forest (an ensemble of decision trees)\n",
"\n",
"mlp -- multilayer perceptron (a neural network)"
]
},
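{
"cell_type": "markdown",
"metadata": {},
"source": [
"`GridSearchCV` tries every combination in the grid, scores each one with k-fold cross-validation and keeps the best. Because `scoring='neg_mean_squared_error'` is negated so that larger is better, the cross-validated RMSE is the square root of `-best_score_`. A small sketch on synthetic data from `make_regression` (illustrative, not the stroke data):\n",
"\n",
"```python\n",
"import math\n",
"from sklearn.datasets import make_regression\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"\n",
"# Synthetic regression data (not the stroke dataset)\n",
"X_syn, y_syn = make_regression(n_samples=200, n_features=4, noise=10.0, random_state=0)\n",
"\n",
"# Every combination of the listed values is evaluated: 3 x 2 = 6 candidates\n",
"small_grid = {'n_neighbors': [3, 5, 10], 'weights': ['uniform', 'distance']}\n",
"search = GridSearchCV(KNeighborsRegressor(), small_grid, scoring='neg_mean_squared_error', cv=5)\n",
"search.fit(X_syn, y_syn)\n",
"\n",
"# The score is negated MSE (higher is better), so RMSE = sqrt(-best_score_)\n",
"best_cv_rmse = math.sqrt(-search.best_score_)\n",
"print(search.best_params_, best_cv_rmse)\n",
"```"
]
},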
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\code\\mai\\labs\\AIM-PIbd-31-Bakalskaya-E-D\\lab_4\\venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters for knn: {'n_jobs': -1, 'n_neighbors': 30, 'weights': 'uniform'}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\code\\mai\\labs\\AIM-PIbd-31-Bakalskaya-E-D\\lab_4\\venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters for random_forest: {'criterion': 'squared_error', 'max_depth': 7, 'max_features': 'sqrt', 'n_estimators': 250, 'n_jobs': -1, 'random_state': 9}\n",
"Best parameters for mlp: {'alpha': np.float64(1e-06), 'early_stopping': False, 'hidden_layer_sizes': np.int64(13), 'max_iter': 1000, 'random_state': 9, 'solver': 'adam'}\n"
]
}
],
"source": [
"# Candidate hyperparameter values for each model\n",
"param_grids = {\n",
" \"knn\": {\n",
" \"n_neighbors\": list(range(1, 31)),\n",
" \"weights\": ['uniform', 'distance'],\n",
" \"n_jobs\": [-1]\n",
" },\n",
" \"random_forest\": {\n",
" \"n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"max_depth\": [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
" \"criterion\": [\"squared_error\", \"absolute_error\", \"poisson\"],\n",
" \"random_state\": [random_state],\n",
" \"n_jobs\": [-1]\n",
" },\n",
" \"mlp\": {\n",
" \"solver\": ['adam'], \n",
" \"max_iter\": list(range(1000, 2001, 100)),\n",
" \"alpha\": 10.0 ** -np.arange(1, 10),\n",
" \"hidden_layer_sizes\": np.arange(10, 15),\n",
" \"early_stopping\": [True, False],\n",
" \"random_state\": [random_state]\n",
" }\n",
"}\n",
"\n",
"# Model instances\n",
"models = {\n",
" \"knn\": neighbors.KNeighborsRegressor(),\n",
" \"random_forest\": ensemble.RandomForestRegressor(),\n",
" \"mlp\": neural_network.MLPRegressor()\n",
"}\n",
"\n",
"# Models configured with their best parameters\n",
"class_models = {}\n",
"\n",
"# Run a grid search for each model\n",
"for model_name, model in models.items():\n",
" # GridSearchCV for the current model (5-fold cross-validation by default)\n",
" gs_optimizer = GridSearchCV(estimator=model, param_grid=param_grids[model_name], scoring='neg_mean_squared_error', n_jobs=-1)\n",
" \n",
" # Fit GridSearchCV on the preprocessed training data\n",
" gs_optimizer.fit(preprocessed_df, y_train.values.ravel())\n",
" \n",
" # Retrieve the best parameters\n",
" best_params = gs_optimizer.best_params_\n",
" print(f\"Best parameters for {model_name}: {best_params}\")\n",
" \n",
" class_models[model_name] = {\n",
" \"model\": model.set_params(**best_params) # configure the model with the best parameters\n",
" }"
]
},
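{
"cell_type": "markdown",
"metadata": {},
"source": [
"Grid search fits every combination once per fold, so the grid sizes above determine the cost. A sketch that counts them with `ParameterGrid` (the grids are copied from the cell above, with `random_state` fixed to 9 as in the printed best parameters; 5 folds is the `GridSearchCV` default):\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.model_selection import ParameterGrid\n",
"\n",
"# Copies of the grids above, random_state fixed to 9\n",
"grids = {\n",
"    'knn': {'n_neighbors': list(range(1, 31)), 'weights': ['uniform', 'distance'], 'n_jobs': [-1]},\n",
"    'random_forest': {\n",
"        'n_estimators': [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"        'max_features': ['sqrt', 'log2', 2],\n",
"        'max_depth': list(range(2, 11)),\n",
"        'criterion': ['squared_error', 'absolute_error', 'poisson'],\n",
"        'random_state': [9],\n",
"        'n_jobs': [-1],\n",
"    },\n",
"    'mlp': {\n",
"        'solver': ['adam'],\n",
"        'max_iter': list(range(1000, 2001, 100)),\n",
"        'alpha': 10.0 ** -np.arange(1, 10),\n",
"        'hidden_layer_sizes': np.arange(10, 15),\n",
"        'early_stopping': [True, False],\n",
"        'random_state': [9],\n",
"    },\n",
"}\n",
"\n",
"n_folds = 5  # GridSearchCV default when cv is not given\n",
"for name, grid in grids.items():\n",
"    n_candidates = len(ParameterGrid(grid))  # product of the list lengths\n",
"    print(name, n_candidates, 'candidates,', n_candidates * n_folds, 'fits')\n",
"```\n",
"\n",
"This gives 60, 810 and 990 candidates (300, 4050 and 4950 fits), so the random forest and MLP searches dominate the run time; `RandomizedSearchCV` with a fixed `n_iter` is a common way to cut that cost. Note also that the best `n_neighbors` (30) sits at the upper edge of its grid, so larger values might score better still."
]
},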
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let us train the models on the training set and evaluate them on the test set."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" \n",
" model = class_models[model_name][\"model\"]\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end_reg), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_pred = model_pipeline.predict(X_train)\n",
" y_test_pred = model_pipeline.predict(X_test)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"train_preds\"] = y_train_pred\n",
" class_models[model_name][\"preds\"] = y_test_pred\n",
" \n",
" class_models[model_name][\"RMSE_train\"] = math.sqrt(\n",
" mean_squared_error(y_train, y_train_pred)\n",
" )\n",
" class_models[model_name][\"RMSE_test\"] = math.sqrt(\n",
" mean_squared_error(y_test, y_test_pred)\n",
" )\n",
" class_models[model_name][\"RMAE_test\"] = math.sqrt(\n",
" mean_absolute_error(y_test, y_test_pred)\n",
" )\n",
" class_models[model_name][\"R2_test\"] = r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RMSE, RMAE, R2:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_c9eb0_row0_col0, #T_c9eb0_row2_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_c9eb0_row0_col1, #T_c9eb0_row1_col0 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c9eb0_row0_col2, #T_c9eb0_row2_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c9eb0_row0_col3, #T_c9eb0_row2_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c9eb0_row1_col1 {\n",
" background-color: #20938c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c9eb0_row1_col2 {\n",
" background-color: #b42e8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c9eb0_row1_col3 {\n",
" background-color: #c8437b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c9eb0_row2_col0 {\n",
" background-color: #75d054;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_c9eb0\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_c9eb0_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_c9eb0_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_c9eb0_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_c9eb0_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_c9eb0_level0_row0\" class=\"row_heading level0 row0\" >mlp</th>\n",
" <td id=\"T_c9eb0_row0_col0\" class=\"data row0 col0\" >42.583378</td>\n",
" <td id=\"T_c9eb0_row0_col1\" class=\"data row0 col1\" >40.922194</td>\n",
" <td id=\"T_c9eb0_row0_col2\" class=\"data row0 col2\" >5.533579</td>\n",
" <td id=\"T_c9eb0_row0_col3\" class=\"data row0 col3\" >0.139061</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c9eb0_level0_row1\" class=\"row_heading level0 row1\" >random_forest</th>\n",
" <td id=\"T_c9eb0_row1_col0\" class=\"data row1 col0\" >40.324186</td>\n",
" <td id=\"T_c9eb0_row1_col1\" class=\"data row1 col1\" >41.085298</td>\n",
" <td id=\"T_c9eb0_row1_col2\" class=\"data row1 col2\" >5.544678</td>\n",
" <td id=\"T_c9eb0_row1_col3\" class=\"data row1 col3\" >0.132184</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c9eb0_level0_row2\" class=\"row_heading level0 row2\" >knn</th>\n",
" <td id=\"T_c9eb0_row2_col0\" class=\"data row2 col0\" >42.166860</td>\n",
" <td id=\"T_c9eb0_row2_col1\" class=\"data row2 col1\" >41.821704</td>\n",
" <td id=\"T_c9eb0_row2_col2\" class=\"data row2 col2\" >5.550619</td>\n",
" <td id=\"T_c9eb0_row2_col3\" class=\"data row2 col3\" >0.100796</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1bdc1acae70>"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
]
},
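{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put the table in perspective, the sketch below computes how much each model reduces test RMSE relative to the mean baseline (RMSE of about 44.13). The numbers are copied by hand from the outputs above:\n",
"\n",
"```python\n",
"# Test RMSE values copied from the outputs above\n",
"baseline_rmse_value = 44.12711275645952\n",
"test_rmse = {'mlp': 40.922194, 'random_forest': 41.085298, 'knn': 41.821704}\n",
"\n",
"# Relative reduction of RMSE compared with the mean-prediction baseline\n",
"improvement = {\n",
"    name: (baseline_rmse_value - rmse) / baseline_rmse_value\n",
"    for name, rmse in test_rmse.items()\n",
"}\n",
"for name, value in improvement.items():\n",
"    print(f'{name}: {value:.1%} lower test RMSE than the baseline')\n",
"```\n",
"\n",
"The reductions are roughly 7.3% (MLP), 6.9% (random forest) and 5.2% (kNN); together with R² of at most about 0.14, this suggests the selected features explain only a small part of the variation in glucose level."
]
},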
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The results as plots:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: knn\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIjCAYAAAAJLyrXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3xT9foH8M852UnTvRhllb0EBXHjYCgKKo6Lyu+K4uIWEERFFEXFq4jIEHBeRa9XXNd1RUUQEVRQUEDBsnfpXtn7nN8foaHpgKZNmo7P+/XyJT1Jkydpcs73+Y7nK8iyLIOIiIiIiIjqTIx2AERERERERM0NEykiIiIiIqIQMZEiIiIiIiIKERMpIiIiIiKiEDGRIiIiIiIiChETKSIiIiIiohAxkSIiIiIiIgoREykiIiIiIqIQMZEiIiIiIiIKERMpIiJqcQRBwJNPPhny7x05cgSCIODtt9+u1/NOmDABMTEx9fpdIiJqXphIERFRRLz99tsQBAGCIOCnn36qdrssy8jIyIAgCLjmmmuiECEREVH9MZEiIqKI0mq1WLlyZbXjGzZsQE5ODjQaTRSiIiIiahgmUkREFFGjRo3Cxx9/DK/XG3R85cqVOOecc5Cenh6lyIiIiOqPiRQREUXULbfcgpKSEqxduzZwzO1247///S9uvfXWGn/HZrNhxowZyMjIgEajQY8ePbBgwQLIshx0P5fLhenTpyMlJQVGoxFjxoxBTk5OjY954sQJ3HnnnUhLS4NGo0GfPn3w1ltvnTF+j8eDPXv2IC8vL4RXfcqOHTuQkpKCSy+9FFarFQDQqVMnXHPNNfjpp59w7rnnQqvVokuXLvj3v/8d9LsV0yN//vlnPPDAA0hJSYHBYMD111+PoqKiesVDREThwUSKiIgiqlOnTjj//PPx/vvvB4598803MJlMGDduXLX7y7KMMWPGYNGiRbjyyiuxcOFC9OjRAw899BAeeOCBoPveddddWLx4MUaMGIF58+ZBpVLh6quvrvaYBQUFOO+88/Ddd99h8uTJWLJkCbp27YqJEydi8eLFp43/xIkT6NWrF2bNmhXya9+6dSsuv/xyDBw4EN98801QIYoDBw7gxhtvxPDhw/Hiiy8iISEBEyZMwF9//VXtcaZMmYI//vgDc+bMwaRJk/Dll19i8uTJIcdDRETho4x2AERE1PLdeuutmDVrFhwOB3Q6Hd577z0MHToUbdu2rXbf//3vf/j+++/xzDPP4LHHHgMAZGVl4aabbsKSJUswefJkZGZm4o8//sB//vMf/OMf/8Dy5csD97vtttvw559/Bj3mY489Bp/Ph507dyIpKQkAcN999+GWW27Bk08+iXvvvRc6nS6sr/nnn3/GqFGjcPHFF+OTTz6pthZs79692LhxIy6++GIAwM0334yMjAysWLECCxYsCLpvUlIS1qxZA0EQAACSJOGll16CyWRCXFxcWOMmIqK64YgUERFF3M033wyHw4FVq1bBYrFg1apVtU7r+/rrr6FQKDB16tSg4zNmzIAsy/jmm28C9wNQ7X7Tpk0L+lmWZXzyyScYPXo0ZFlGcXFx4L+RI0fCZDJh27ZttcbeqVMnyLIcUkn09evXY+TIkbjiiivw6aef1lhQo3fv3oEkCgBSUlLQo0cPHDp0qNp977nnnkASBQAXX3wxfD4fjh49WueYiIgovDgiRUREEZeSkoJhw4Zh5cqVsNvt8Pl8uPHGG2u879GjR9G2bVsYjcag47169QrcXvF/URSRmZkZdL8ePXoE/VxUVITy8nK8/vrreP3112t8zsLCwnq9rpo4nU5cffXVOOecc/DRRx9Bqaz5UtuhQ4dqxxISElBWVnbG+yYkJABAjfclIqLGwUSKiIgaxa233oq7774b+fn5uOqqqxAfH98ozytJEgBg/PjxuP3222u8T//+/cP2fBqNBqNGjcIXX3yB1atX17pHlkKhqPF41YIaod6XiIgaBx
MpIiJqFNdffz3uvfde/PLLL/jwww9rvV/Hjh3x3XffwWKxBI1K7dmzJ3B7xf8lScLBgweDRqH27t0b9HgVFf18Ph+GDRsWzpdUI0EQ8N577+Haa6/FTTfdhG+++QaXXnppxJ+XiIgaF9dIERFRo4iJicErr7yCJ598EqNHj671fqNGjYLP58OyZcuCji9atAiCIOCqq64CgMD/X3rppaD7Va3Cp1AocMMNN+CTTz7Brl27qj3fmcqI16f8uVqtxqefforBgwdj9OjR2LJlS51/l4iImgeOSBERUaOpbWpdZaNHj8Zll12Gxx57DEeOHMFZZ52FNWvW4IsvvsC0adMCa6IGDBiAW265BS+//DJMJhMuuOACrFu3DgcOHKj2mPPmzcP69esxZMgQ3H333ejduzdKS0uxbds2fPfddygtLa01nory57fffntIBSd0Oh1WrVqFyy+/HFdddRU2bNiAvn371vn3iYioaeOIFBERNSmiKOJ///sfpk2bhlWrVmHatGnIzs7GCy+8gIULFwbd96233sLUqVOxevVqPPzww/B4PPjqq6+qPWZaWhq2bNmCO+64A59++mlgL6nS0lI8//zzEXstsbGx+Pbbb5Geno7hw4fXmOQREVHzJMhcqUpERERERBQSjkgRERERERGFiIkUERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREVGIuCEvAEmSkJubC6PRCEEQoh0OERERERFFiSzLsFgsaNu2LUSx9nEnJlIAcnNzkZGREe0wiIiIiIioiTh+/Djat29f6+1MpAAYjUYA/jcrNjY2ytEQEREREVG0mM1mZGRkBHKE2jCRAgLT+WJjY5lIERERERHRGZf8sNgEERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREVGImEgRERERERGFiIkUERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREVGImEgRERERERGFiIkUERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREUWFLMs4cOBAtMOoFyZSRERERETU6Hbt2oXLLrsMgwYNQmFhYbTDCRkTKSIiIiIiajQmkwnTpk3DgAEDsGHDBphMJsyaNSvaYYVMGe0AiIiIiIio5ZMkCe+++y4efvjhoBGozMxMjB07NoqR1Q8TKSIiIiIiiqgdO3YgKysLmzZtChzT6XR49NFH8eCDD0Kr1UYxuvphIkVERERERBFhMpnw6KOP4tVXX4UkSYHjY8eOxcKFC9GxY8coRtcwTKSIiIiIiCgiZFnGf//730AS1b17dyxduhQjRoyIcmQNx2ITREREREQUEfHx8Zg/fz4MBgPmzZuHnTt3togkCmAiRUREREREYVBSUoIpU6bgxIkTQcf/7//+D/v378fMmTOhVqujFF34cWofERERERHVm8/nwxtvvIHHHnsMpaWlKCkpwcqVKwO3i6KINm3aRDHCyOCIFBERERER1cvmzZtx7rnnYtKkSSgtLQUArFq1Cnl5eVGOLPKYSBERERERUUgKCwtx55134oILLsC2bdsCx8ePH4+9e/e2yBGoqji1j4iIiIiI6sTr9eKVV17B448/DpPJFDjer18/LF++HBdffHEUo2tcTKSIiIiIiKhOrrvuOnz11VeBn+Pi4jB37lxMmjQJSmXrSi04tY+IiIiIiOrk73//e+DfEyZMwN69ezFlypRWl0QBHJEiIiIiIqIaeDweWCwWJCYmBo7ddNNN+Omnn3DLLbfg/PPPj2J00ccRKSIiIiIiCvLDDz9g4MCBmDhxYtBxQRDw0ksvtfokCmAiRUREREREJ+Xk5OCWW27BZZddhr/++guff/45Vq9eHe2wmiQmUkRERERErZzb7cb8+fPRs2dPfPDBB4HjgwcPRmpqahQja7q4RoqIiIiIqBVbu3
YtpkyZgr179waOJSUlYd68ebjzzjshihx7qQnfFSIiIiKiVujYsWO44YYbMGLEiEASJYoi/vGPf2Dfvn246667mES
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: random_forest\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIjCAYAAAAJLyrXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hT9f4H8Pc5WU3TNN0to6VQ9kZRRFQQGYKCigsUBcGFDEFUhB8KilcBkSG4ryLXK67ruqIyRMR7BS8ooGDZG0p3m6bZ4/z+CD02dNC0aZO079fz8NB+c5p8cnJyzvl8pyBJkgQiIiIiIiKqMTHYARAREREREYUbJlJERERERER+YiJFRERERETkJyZSREREREREfmIiRURERERE5CcmUkRERERERH5iIkVEREREROQnJlJERERERER+YiJFRERERETkJyZSRERULUEQMH/+fL//7sSJExAEAe+9917AY2oo4fAeXnrpJbRp0wYKhQI9e/YMdjhERE0GEykiojDw3nvvQRAECIKA//73vxUelyQJqampEAQBN954YxAipGDYuHEjnnzySfTr1w+rV6/GCy+8EOyQKsjMzMT8+fNx4sSJYIdCRBRQymAHQERENRcREYG1a9fiqquu8infunUrzpw5A41GE6TIKBh++OEHiKKId955B2q1OtjhVCozMxPPPvssBgwYgPT09GCHQ0QUMGyRIiIKI8OHD8enn34Kl8vlU7527VpceumlSElJCVJkDcNsNgc7hJCSm5sLrVYbsCRKkiRYrdaAPBcRUWPHRIqIKIyMGTMGBQUF2LRpk1zmcDjwr3/9C3fddVelf2M2mzFz5kykpqZCo9GgQ4cOWLJkCSRJ8tnObrdjxowZSExMhF6vx8iRI3HmzJlKn/Ps2bOYMGECkpOTodFo0KVLF7z77rsXjd/pdOLAgQM4d+7cRbcdP348oqKicPToUQwfPhx6vR533303AOA///kPbr/9dqSlpUGj0SA1NRUzZsyokASUPcfZs2dx8803IyoqComJiXj88cfhdrt9ti0uLsb48eNhMBgQExODcePGobi4uNLYfvjhB1x99dXQ6XSIiYnBTTfdhP379/tsM3/+fAiCgEOHDmHs2LEwGAxITEzE008/DUmScPr0adx0002Ijo5GSkoKXn755Yvuk/IEQcDq1athNpvlbp9lY7lcLhcWLFiAjIwMaDQapKenY86cObDb7T7PkZ6ejhtvvBEbNmxA7969odVq8eabb8r7Y/r06fJx07ZtWyxatAgej8fnOT766CNceuml0Ov1iI6ORrdu3bBixQoA3i6pt99+OwDg2muvleP88ccf/XqvREShiIkUEVEYSU9PR9++ffHhhx/KZd999x2MRiNGjx5dYXtJkjBy5EgsW7YM119/PZYuXYoOHTrgiSeewGOPPeaz7f3334/ly5djyJAhWLhwIVQqFW644YYKz5mTk4MrrrgC33//PaZMmYIVK1agbdu2mDhxIpYvX15t/GfPnkWnTp0we/bsGr1fl8uFoUOHIikpCUuWLMGtt94KAPj0009hsVgwadIkrFy5EkOHDsXKlStx7733VngOt9uNoUOHIj4+HkuWLEH//v3x8ssv46233vLZTzfddBPef/99jB07Fs8//zzOnDmDcePGVXi+77//HkOHDkVubi7mz5+Pxx57DNu2bUO/fv0qHQd05513wuPxYOHChejTpw+ef/55LF++HIMHD0aLFi2waNEitG3bFo8//jh++umnGu0XAHj//fdx9dVXQ6PR4P3338f777+Pa665BoD3s3zmmWdwySWXYNmyZejfvz9efPHFSo+RgwcPYsyYMRg8eDBWrFiBnj17wmKxoH///vjnP/+Je++9F6+88gr69euH2bNn+xw3mzZtwpgxYxAbG4tFixZh4cKFGDBgAH7++WcAwDXXXINp06YBAObMmSPH2alTpxq/TyKikCUREVHIW716tQ
RA2rlzp7Rq1SpJr9dLFotFkiRJuv3226Vrr71WkiRJatWqlXTDDTfIf/fll19KAKTnn3/e5/luu+02SRAE6ciRI5IkSdKePXskANIjjzzis91dd90lAZDmzZsnl02cOFFq1qyZlJ+f77Pt6NGjJYPBIMd1/PhxCYC0evVqeZuysnHjxl30PY8bN04CID311FMVHit7jfJefPFFSRAE6eTJkxWe47nnnvPZtlevXtKll14q/162nxYvXiyXuVwu6eqrr67wHnr27CklJSVJBQUFctnvv/8uiaIo3XvvvXLZvHnzJADSgw8+6POcLVu2lARBkBYuXCiXFxUVSVqttkb7pbxx48ZJOp3Op6zss7z//vt9yh9//HEJgPTDDz/IZa1atZIASOvXr/fZdsGCBZJOp5MOHTrkU/7UU09JCoVCOnXqlCRJkvToo49K0dHRksvlqjLGTz/9VAIgbdmyxa/3RkQU6tgiRUQUZu644w5YrVasW7cOJpMJ69atq7Jb37fffguFQiG3CpSZOXMmJEnCd999J28HoMJ206dP9/ldkiR89tlnGDFiBCRJQn5+vvxv6NChMBqN2LVrV5Wxp6enQ5Ikv6YTnzRpUoUyrVYr/2w2m5Gfn48rr7wSkiRh9+7dFbZ/+OGHfX6/+uqrcezYMfn3b7/9Fkql0ue1FAoFpk6d6vN3586dw549ezB+/HjExcXJ5d27d8fgwYPl/Vje/fff7/OcvXv3hiRJmDhxolweExODDh06+MRUW2UxXNjiOHPmTADAN99841PeunVrDB061Kfs008/xdVXX43Y2Fifz3jQoEFwu91yy1lMTAzMZrNPV1MioqaCs/YREYWZxMREDBo0CGvXroXFYoHb7cZtt91W6bYnT55E8+bNodfrfcrLuladPHlS/l8URWRkZPhs16FDB5/f8/LyUFxcjLfeesuna1x5ubm5tXpflVEqlWjZsmWF8lOnTuGZZ57Bv//9bxQVFfk8ZjQafX6PiIhAYmKiT1lsbKzP3508eRLNmjVDVFSUz3YXvv+y/XVhOeDdpxs2bIDZbIZOp5PL09LSfLYzGAyIiIhAQkJChfKCgoIKz+uvss+ybdu2PuUpKSmIiYmR30OZ1q1bV3iOw4cP448//qiw38qUfcaPPPIIPvnkEwwbNgwtWrTAkCFDcMcdd+D666+v8/sgIgp1TKSIiMLQXXfdhQceeADZ2dkYNmwYYmJiGuR1yyYaGDt2bKXjhwBv60ygaDQaiKJv5wm3243BgwejsLAQs2bNQseOHaHT6XD27FmMHz++wmQICoUiYPHURmWvX1VM0gUTgNSFIAg12q58614Zj8eDwYMH48knn6z0b9q3bw8ASEpKwp49e7BhwwZ89913+O6777B69Wrce++9WLNmTe2DJyIKA0ykiIjC0C233IKHHnoIv/zyCz7++OMqt2vVqhW+//57mEwmn1apAwcOyI+X/e/xeHD06FGf1paDBw/6PF/ZjH5utxuDBg0K5Fuqsb179+LQoUNYs2aNz+QSdele1qpVK2zevBmlpaU+rVIXvv+y/XVhOeDdpwkJCT6tUcFQ9lkePnzYZ1KHnJwcFBcXy++hOhkZGSgtLa3RZ6xWqzFixAiMGDECHo8HjzzyCN588008/fTTaNu2bY0TOiKicMMxUkREYSgqKgqvv/465s+fjxEjRlS53fDhw+F2u7Fq1Sqf8mXLlkEQBAwbNgwA5P9feeUVn+0unIVPoVDg1ltvxWeffYZ9+/ZVeL28vLxq4/Zn+vOqlLXmlG+9kSRJnnK7NoYPHw6Xy4XXX39dLnO73Vi5cqXPds2aNUPPnj2xZs0an6nR9+3bh40bN2L48OG1jiFQymK48LNbunQpAFQ6E+OF7rjjDmzfvh0bNmyo8FhxcbG8jtmFXRFFUZRbJMumWi9LLKuaSp6IKFyxRYqIKExV1bWuvBEjRuDaa6/F//3f/+HEiRPo0aMHNm7ciK+++grTp0+Xx0T17N
kTY8aMwWuvvQaj0Ygrr7wSmzdvxpEjRyo858KFC7Flyxb06dMHDzzwADp37ozCwkLs2rUL33//PQoLC6uMp2z683H
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: mlp\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIjCAYAAAAJLyrXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hUZdoG8Puc6S2dJJRAIHQBQUFUXLGgKAqr2FBZC9gwgCAq4qrYVgGpAtZVdHdFXVdXP3FtiwgWXFFAQXqTkB6STG9nzvn+mGRkSAKZZCYzSe7fdXlJ3pnMPDOZOec8b3leQVEUBURERERERNRoYrwDICIiIiIiam2YSBEREREREUWIiRQREREREVGEmEgRERERERFFiIkUERERERFRhJhIERERERERRYiJFBERERERUYSYSBEREREREUWIiRQREREREVGEmEgREVGbIwgCHnvssYh/79ChQxAEAa+//nrUY2qMr776CoIg4KuvvorL8xMRUeMxkSIioph4/fXXIQgCBEHAN998U+d2RVGQk5MDQRBw+eWXxyFCIiKipmMiRUREMaXX67F69eo67evXr8eRI0eg0+niEBUREVHzMJEiIqKYGjNmDN59911IkhTWvnr1apx++unIzs6OU2RERERNx0SKiIhi6vrrr8fRo0fxxRdfhNp8Ph/+9a9/4YYbbqj3d5xOJ2bNmoWcnBzodDr06dMHCxcuhKIoYffzer2YOXMmOnToAIvFgnHjxuHIkSP1PmZhYSEmTZqErKws6HQ6nHLKKXjttddOGr/f78euXbtQXFx80vvecsstMJvNOHz4MC6//HKYzWZ07twZK1euBABs27YNF1xwAUwmE7p161bvSN3xzjvvPAwYMAA//fQTzj77bBgMBnTv3h0vvvjiSX+XiIhih4kUERHFVG5uLs466yy89dZbobZPPvkEVqsVEyZMqHN/RVEwbtw4LFmyBJdccgkWL16MPn364P7778e9994bdt/bbrsNS5cuxcUXX4x58+ZBo9Hgsssuq/OYpaWlOPPMM/Hf//4XU6dOxbJly9CzZ09MnjwZS5cuPWH8hYWF6NevH+bMmdOo1xsIBHDppZciJycHCxYsQG5uLqZOnYrXX38dl1xyCYYOHYr58+fDYrHgpptuwsGDB0/6mFVVVRgzZgxOP/10LFiwAF26dMGUKVMalQgSEVGMKERERDGwatUqBYCyadMmZcWKFYrFYlFcLpeiKIpyzTXXKOeff76iKIrSrVs35bLLLgv93gcffKAAUJ566qmwx7v66qsVQRCUffv2KYqiKFu3blUAKHfffXfY/W644QYFgDJ37txQ2+TJk5WOHTsqFRUVYfedMGGCkpycHIrr4MGDCgBl1apVofvUtt18880nfc0333yzAkB5+umnQ21VVVWKwWBQBEFQ3n777VD7rl276sS5bt06BYCybt26UNvIkSMVAMqiRYtCbV6vVxk8eLCSmZmp+Hy+k8ZFRETRxxEpIiKKuWuvvRZutxtr1qyB3W7HmjVrGpzW95///AcqlQrTp08Pa581axYURcEnn3wSuh+AOvebMWNG2M+KouC9997D2LFjoSgKKioqQv+NHj0aVqsVmzdvbjD23NxcKIoSUUn02267LfTvlJQU9OnTByaTCddee22ovU+fPkhJScGBAwdO+nhqtRp33nln6GetVos777wTZWVl+OmnnxodFxERRY863gEQEVHb16FDB4waNQqrV6+Gy+VCIBDA1VdfXe99f/vtN3Tq1AkWiyWsvV+/fqHba/8viiLy8vLC7tenT5+wn8vLy1FdXY2XX34ZL7/8cr3PWVZW1qTXVR+9Xo8OHTqEtSUnJ6NLly4QBKFOe1VV1Ukfs1OnTjCZTGFtvXv3BhDc++rMM89sZtRERBQpJlJERNQibrjhBtx+++0oKSnBpZdeipSUlBZ5XlmWAQATJ07EzTffXO99Bg0aFLXnU6lUEbUrxx
XQICKi1oGJFBERtYgrr7wSd955J77//nu88847Dd6vW7du+O9//wu73R42KrVr167Q7bX/l2UZ+/fvDxuF2r17d9jj1Vb0CwQCGDVqVDRfUospKiqC0+kMG5Xas2cPgODUQyIianlcI0VERC3CbDbjhRdewGOPPYaxY8c2eL8xY8YgEAhgxYoVYe1LliyBIAi49NJLASD0/+eeey7sfsdX4VOpVLjqqqvw3nvvYfv27XWer7y8/IRxR1L+PFYkScJLL70U+tnn8+Gll15Chw4dcPrpp8ctLiKi9owjUkRE1GIamlp3rLFjx+L888/Hn//8Zxw6dAinnnoqPv/8c3z44YeYMWNGaE3U4MGDcf311+P555+H1WrF2WefjbVr12Lfvn11HnPevHlYt24dhg8fjttvvx39+/dHZWUlNm/ejP/+97+orKxsMJ7a8uc333xzRAUnoqlTp06YP38+Dh06hN69e+Odd97B1q1b8fLLL0Oj0cQlJiKi9o6JFBERJRRRFPF///d/ePTRR/HOO+9g1apVyM3NxbPPPotZs2aF3fe1115Dhw4d8Oabb+KDDz7ABRdcgI8//hg5OTlh98vKysIPP/yAJ554Au+//z6ef/55pKen45RTTsH8+fNb8uU1SWpqKt544w1MmzYNr7zyCrKysrBixQrcfvvt8Q6NiKjdEhSuciUiIkpY5513HioqKuqdlkhERPHDNVJEREREREQRYiJFREREREQUISZSREREREREEeIaKSIiIiIioghxRIqIiIiIiChCTKSIiIiIiIgixH2kAMiyjKKiIlgsFgiCEO9wiIiIiIgoThRFgd1uR6dOnSCKDY87MZECUFRUVGfzRiIiIiIiar8KCgrQpUuXBm9nIgXAYrEACL5ZSUlJcY6GiIiIiIjixWazIScnJ5QjNISJFBCazpeUlMREioiIiIiITrrkh8UmiIiIiIiIIsREioiIiIiIKEJMpIiIiIiIiCLERIqIiIiIiChCTKSIiIiIiIgixESKiIiIiIgoQkykiIiIiIiIIsREioiIiIiIKEJMpIiIiIiIiCLERIqIiIiIiChCTKSIiIiIiIgixESKiIiIiIgoQkykiIiIiIiIIsREioiIiIiIKEJMpIiIiIiIiCLERIqIiIiIiChCTKSIiIiIiCguFEXBvn374h1GkzCRIiIiIiKiFrd9+3acf/75GDp0KMrKyuIdTsSYSBERERERUYuxWq2YMWMGBg8ejPXr18NqtWLOnDnxDiti6ngHQEREREREbZ8sy/j73/+OBx54IGwEKi8vD+PHj49jZE3DRIqIiIiIiGJq69atyM/Px3fffRdqMxgMeOihh3DfffdBr9fHMbqmYSJFREREREQxYbVa8dBDD+HFF1+ELMuh9vHjx2Px4sXo1q1bHKNrHiZSREREREQUE4qi4F//+lcoierduzeWL1+Oiy++OM6RNR+LTRARERERUUykpKRgwYIFMJlMmDdvHrZt29YmkiiAiRQREREREUXB0aNHMW3aNBQWFoa1/+lPf8LevXsxe/ZsaLXaOEUXfZzaR0RERERETRYIBPDKK6/gz3/+MyorK3H06FGsXr06dLsoiujYsWMcI4wNjkgREREREVGTbNy4EWeccQamTJmCyspKAMCaNWtQXFwc58hij4kUERERERFFpKysDJMmTcLZZ5+NzZs3h9onTpyI3bt3t8kRqONxah8RERERETWKJEl44YUX8Mgjj8BqtYbaBw4ciJUrV+IPf/hDHKNrWUykiIiIiIioUa644gp8/PHHoZ+Tk5Px5JNPYsqUKVCr21dqwal9RERERETUKDfddFPo37fccgt2796NadOmtbskCuCIFBERERER1cPv98NutyMtLS3Uds011+Cbb77B9ddfj7POOiuO0cUfR6SIiIiIiCjMV199hSFDhmDy5Mlh7YIg4Lnnnmv3SRTARIqIiIiIiGocOXIE119/Pc4//3z8+uuv+OCDD/Dpp5/GO6
yExESKiIiIiKid8/l8WLBgAfr27Yu333471D5s2DBkZmbGMbLExTVSRERERETt2BdffIFp06Zh9+7dobb09HTMmzc
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Scatter plot of predicted vs. actual values for every model\n",
"for model_name, model_data in class_models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    y_pred = model_data[\"preds\"]\n",
"    plt.figure(figsize=(10, 6))\n",
"    plt.scatter(y_test, y_pred, alpha=0.5)\n",
"    # Dashed y = x reference line: perfect predictions would lie on it\n",
"    plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)\n",
"    plt.xlabel('Actual glucose level')\n",
"    plt.ylabel('Predicted glucose level')\n",
"    plt.title(f\"Model: {model_name}\")\n",
"    plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plots above show that, overall, the models do not perform particularly well: their predictions are widely scattered around the ideal line y = x, which indicates substantial deviations of the predicted values from the actual ones.\n",
"\n",
"Nevertheless, every model outperforms the baseline on all metrics, although perhaps not by a large margin. The most noticeable improvement is in R², which goes from a negative value for the baseline to a positive one for the models, meaning the models explain at least part of the variance in the data.\n",
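"\n",
"For reference, R² is computed as `1 - SS_res / SS_tot`, where `SS_res` is the sum of squared residuals and `SS_tot` is the sum of squared deviations of the actual values from their mean. Always predicting the mean therefore gives R² = 0, and any predictor with larger errors than that gets a negative R². A minimal sketch on made-up numbers (illustrative only, not taken from this dataset; the helper `r2` is hypothetical and written just for this example):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def r2(y_true, y_pred):\n",
"    # Hypothetical helper: R^2 = 1 - SS_res / SS_tot\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_pred = np.asarray(y_pred, dtype=float)\n",
"    ss_res = np.sum((y_true - y_pred) ** 2)  # residual sum of squares\n",
"    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares around the mean\n",
"    return float(1.0 - ss_res / ss_tot)\n",
"\n",
"\n",
"# Illustrative values only\n",
"y_true = np.array([90.0, 110.0, 130.0, 150.0])\n",
"mean_pred = np.full_like(y_true, y_true.mean())  # constant prediction: the mean\n",
"bad_pred = np.array([150.0, 130.0, 110.0, 90.0])  # larger errors than the mean\n",
"good_pred = np.array([95.0, 105.0, 135.0, 145.0])  # close to the actual values\n",
"\n",
"for name, pred in [('mean', mean_pred), ('bad', bad_pred), ('good', good_pred)]:\n",
"    print(name, round(r2(y_true, pred), 3))\n",
"# mean 0.0\n",
"# bad -3.0\n",
"# good 0.95\n",
"```\n",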
"\n",
"In addition, all models show moderate variance and are not strongly prone to overfitting, since the gap between the training and test RMSE is small.\n",
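"\n",
"The train/test comparison behind this statement can be sketched as follows. This is a minimal, hypothetical example on synthetic data: a straight-line fit (`np.polyfit`) stands in for the notebook's models, and the helper names `rmse` and `rmse_gap` are made up for illustration. A test RMSE much larger than the training RMSE would point to overfitting; a small gap suggests the model generalizes about as well as it fits the training data.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def rmse(y_true, y_pred):\n",
"    # Hypothetical helper: root mean squared error\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_pred = np.asarray(y_pred, dtype=float)\n",
"    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))\n",
"\n",
"\n",
"def rmse_gap(y_train, train_pred, y_test, test_pred):\n",
"    # Hypothetical helper: returns (train RMSE, test RMSE, test minus train)\n",
"    train_rmse = rmse(y_train, train_pred)\n",
"    test_rmse = rmse(y_test, test_pred)\n",
"    return train_rmse, test_rmse, test_rmse - train_rmse\n",
"\n",
"\n",
"# Synthetic data: y = 3x + 5 plus Gaussian noise with standard deviation 2\n",
"rng = np.random.default_rng(0)\n",
"x = rng.uniform(0, 10, size=200)\n",
"y = 3.0 * x + 5.0 + rng.normal(0, 2.0, size=200)\n",
"x_tr, x_te, y_tr, y_te = x[:150], x[150:], y[:150], y[150:]\n",
"\n",
"# Linear least-squares fit on the training part only\n",
"slope, intercept = np.polyfit(x_tr, y_tr, deg=1)\n",
"train_rmse, test_rmse, gap = rmse_gap(\n",
"    y_tr, slope * x_tr + intercept, y_te, slope * x_te + intercept\n",
")\n",
"print(f'train RMSE = {train_rmse:.2f}, test RMSE = {test_rmse:.2f}, gap = {gap:.2f}')\n",
"```\n",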
"\n",
"Conclusions:\n",
"- Best model: MLP, since it has the lowest RMSE and the highest R², indicating the best accuracy and the largest share of explained variance in the target variable.\n",
"\n",
"- Random Forest: close to MLP in performance, with a slightly higher RMSE, but a more stable model with small differences between training and test results.\n",
"\n",
"- KNN: the weakest model, with the largest errors and a low R²; it needs further improvement, or a different model should be chosen for this task."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}