3299 lines
184 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"# Stroke data\n",
"\n",
"Load the dataset and display its columns:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>gender</th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>hypertension</th>\n",
|
||
" <th>heart_disease</th>\n",
|
||
" <th>ever_married</th>\n",
|
||
" <th>work_type</th>\n",
|
||
" <th>Residence_type</th>\n",
|
||
" <th>avg_glucose_level</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>smoking_status</th>\n",
|
||
" <th>stroke</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>id</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>9046</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>67.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>228.69</td>\n",
|
||
" <td>36.6</td>\n",
|
||
" <td>formerly smoked</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>51676</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>61.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>202.21</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>31112</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>80.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>105.92</td>\n",
|
||
" <td>32.5</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>60182</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>49.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>171.23</td>\n",
|
||
" <td>34.4</td>\n",
|
||
" <td>smokes</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1665</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>79.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>174.12</td>\n",
|
||
" <td>24.0</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>18234</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>80.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>83.75</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>44873</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>81.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>125.20</td>\n",
|
||
" <td>40.0</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>19723</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>35.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>82.99</td>\n",
|
||
" <td>30.6</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>37544</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>51.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>166.29</td>\n",
|
||
" <td>25.6</td>\n",
|
||
" <td>formerly smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>44679</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>44.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Govt_job</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>85.28</td>\n",
|
||
" <td>26.2</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>5110 rows × 11 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" gender age hypertension heart_disease ever_married work_type \\\n",
|
||
"id \n",
|
||
"9046 Male 67.0 0 1 Yes Private \n",
|
||
"51676 Female 61.0 0 0 Yes Self-employed \n",
|
||
"31112 Male 80.0 0 1 Yes Private \n",
|
||
"60182 Female 49.0 0 0 Yes Private \n",
|
||
"1665 Female 79.0 1 0 Yes Self-employed \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"18234 Female 80.0 1 0 Yes Private \n",
|
||
"44873 Female 81.0 0 0 Yes Self-employed \n",
|
||
"19723 Female 35.0 0 0 Yes Self-employed \n",
|
||
"37544 Male 51.0 0 0 Yes Private \n",
|
||
"44679 Female 44.0 0 0 Yes Govt_job \n",
|
||
"\n",
|
||
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
|
||
"id \n",
|
||
"9046 Urban 228.69 36.6 formerly smoked 1 \n",
|
||
"51676 Rural 202.21 NaN never smoked 1 \n",
|
||
"31112 Rural 105.92 32.5 never smoked 1 \n",
|
||
"60182 Urban 171.23 34.4 smokes 1 \n",
|
||
"1665 Rural 174.12 24.0 never smoked 1 \n",
|
||
"... ... ... ... ... ... \n",
|
||
"18234 Urban 83.75 NaN never smoked 0 \n",
|
||
"44873 Urban 125.20 40.0 never smoked 0 \n",
|
||
"19723 Rural 82.99 30.6 never smoked 0 \n",
|
||
"37544 Rural 166.29 25.6 formerly smoked 0 \n",
|
||
"44679 Urban 85.28 26.2 Unknown 0 \n",
|
||
"\n",
|
||
"[5110 rows x 11 columns]"
|
||
]
|
||
},
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"# Make scikit-learn transformers return pandas DataFrames instead of NumPy arrays.\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Fixed seed so the splits below are reproducible.\n",
"random_state = 9\n",
"\n",
"df = pd.read_csv(\"./csv/option4.csv\", index_col=\"id\")\n",
"\n",
"df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"### Business goals\n",
"#### Classification\n",
"##### Goal: build a model that predicts stroke risk from health, lifestyle, and sociodemographic data.\n",
"\n",
"Applications:\n",
"\n",
"- For physicians: identify high-risk patients so they can be treated in time.\n",
"- For healthcare systems: integrate the model into patient records so the system alerts the physician automatically.\n",
"- For the public: raise awareness of risk factors and how to avoid them.\n",
"\n",
"#### Regression\n",
"##### Goal: build a model that predicts the glucose level from the same factors. This helps track changes and assess risks.\n",
"\n",
"Applications:\n",
"\n",
"- For physicians: find patients at risk of diabetes and related problems so that prevention can start right away.\n",
"- For healthcare systems: integrate the model into patient records so physicians see an estimated glucose level even without lab results.\n",
"- For the public: explain how lifestyle affects glucose and recommend changes.\n",
"\n",
"### Model quality\n",
"#### Classification (stroke):\n",
"\n",
"- Useful features: age, hypertension, and heart disease are informative, so the model should distinguish risk levels reasonably well.\n",
"- Limitations: there is no data on genetics, detailed medical history, diet, or physical activity, which caps the achievable accuracy.\n",
"\n",
"#### Regression (glucose):\n",
"\n",
"- Useful features: age, smoking, and blood pressure are related to glucose, so the model has some predictive signal.\n",
"- Limitations: data on diet, activity, and hormones is missing, so accuracy is limited."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"## Classification\n",
"\n",
"Split the dataset into training and test sets (80/20). The target feature is `stroke`."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
"import math\n",
"from typing import Tuple\n",
"\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input,\n",
"    stratify_colname=\"y\",\n",
"    frac_train=0.6,\n",
"    frac_val=0.15,\n",
"    frac_test=0.25,\n",
"    random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"\n",
"    Splits a Pandas dataframe into three subsets (train, val, and test)\n",
"    following fractional ratios provided by the user, where each subset is\n",
"    stratified by the values in a specific column (that is, each subset has\n",
"    the same relative frequency of the values in the column). It performs this\n",
"    splitting by running train_test_split() twice.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df_input : Pandas dataframe\n",
"        Input dataframe to be split.\n",
"    stratify_colname : str\n",
"        The name of the column that will be used for stratification. Usually\n",
"        this column would be for the label.\n",
"    frac_train : float\n",
"    frac_val : float\n",
"    frac_test : float\n",
"        The ratios with which the dataframe will be split into train, val, and\n",
"        test data. The values should be expressed as float fractions and should\n",
"        sum to 1.0.\n",
"    random_state : int, None, or RandomStateInstance\n",
"        Value to be passed to train_test_split().\n",
"\n",
"    Returns\n",
"    -------\n",
"    df_train, df_val, df_test, y_train, y_val, y_test :\n",
"        Dataframes containing the three splits, followed by the matching\n",
"        single-column dataframes of the stratification column. When\n",
"        frac_val <= 0, df_val and y_val are empty dataframes.\n",
"    \"\"\"\n",
"\n",
"    # Compare with a tolerance: sums such as 0.7 + 0.2 + 0.1 are not exactly 1.0.\n",
"    if not math.isclose(frac_train + frac_val + frac_test, 1.0):\n",
"        raise ValueError(\n",
"            \"fractions %f, %f, %f do not add up to 1.0\"\n",
"            % (frac_train, frac_val, frac_test)\n",
"        )\n",
"\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"\n",
"    X = df_input  # Contains all columns.\n",
"    y = df_input[\n",
"        [stratify_colname]\n",
"    ]  # Dataframe of just the column on which to stratify.\n",
"\n",
"    # Split original dataframe into train and temp dataframes.\n",
"    df_train, df_temp, y_train, y_temp = train_test_split(\n",
"        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
"    )\n",
"\n",
"    if frac_val <= 0:\n",
"        assert len(df_input) == len(df_train) + len(df_temp)\n",
"        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
"\n",
"    # Split the temp dataframe into val and test dataframes.\n",
"    relative_frac_test = frac_test / (frac_val + frac_test)\n",
"    df_val, df_test, y_val, y_test = train_test_split(\n",
"        df_temp,\n",
"        y_temp,\n",
"        stratify=y_temp,\n",
"        test_size=relative_frac_test,\n",
"        random_state=random_state,\n",
"    )\n",
"\n",
"    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"    return df_train, df_val, df_test, y_train, y_val, y_test"
|
||
]
|
||
},
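The two-step split used by the function above can be illustrated on a toy frame. This is a minimal sketch, assuming synthetic data (1000 rows with 10% positive labels) rather than the stroke dataset; the names `toy`, `sizes`, and `positive_rates` are hypothetical and exist only for this illustration.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 1000 rows, 10% positive labels (not the stroke dataset).
rng = np.random.default_rng(0)
toy = pd.DataFrame({"x": rng.normal(size=1000), "label": [1] * 100 + [0] * 900})

frac_train, frac_val, frac_test = 0.6, 0.15, 0.25

# Step 1: carve off the training set; the remainder holds val + test.
train, temp = train_test_split(
    toy, stratify=toy["label"], test_size=1.0 - frac_train, random_state=0
)

# Step 2: split the remainder; test takes frac_test / (frac_val + frac_test) of it.
relative_frac_test = frac_test / (frac_val + frac_test)
val, test = train_test_split(
    temp, stratify=temp["label"], test_size=relative_frac_test, random_state=0
)

parts = {"train": train, "val": val, "test": test}
sizes = {name: len(part) for name, part in parts.items()}
positive_rates = {name: part["label"].mean() for name, part in parts.items()}

print(sizes)
print(positive_rates)
```

Because both calls pass `stratify`, each part should keep roughly the same share of positive labels as the full frame, which matters when the positive class is rare.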
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>gender</th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>hypertension</th>\n",
|
||
" <th>heart_disease</th>\n",
|
||
" <th>ever_married</th>\n",
|
||
" <th>work_type</th>\n",
|
||
" <th>Residence_type</th>\n",
|
||
" <th>avg_glucose_level</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>smoking_status</th>\n",
|
||
" <th>stroke</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>id</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>22159</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>54.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>97.06</td>\n",
|
||
" <td>28.5</td>\n",
|
||
" <td>formerly smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>8920</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>51.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>76.35</td>\n",
|
||
" <td>33.5</td>\n",
|
||
" <td>formerly smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>65507</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>33.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>55.72</td>\n",
|
||
" <td>38.2</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>43196</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>52.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>59.54</td>\n",
|
||
" <td>42.2</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>59745</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>27.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>76.74</td>\n",
|
||
" <td>53.9</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>66546</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>20.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>80.08</td>\n",
|
||
" <td>25.1</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>68798</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>58.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>59.86</td>\n",
|
||
" <td>28.0</td>\n",
|
||
" <td>formerly smoked</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61409</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>32.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Govt_job</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>58.24</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>formerly smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>69259</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>77.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>100.85</td>\n",
|
||
" <td>29.5</td>\n",
|
||
" <td>smokes</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>17231</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>24.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>90.42</td>\n",
|
||
" <td>24.3</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>4088 rows × 11 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" gender age hypertension heart_disease ever_married work_type \\\n",
|
||
"id \n",
|
||
"22159 Female 54.0 1 0 No Private \n",
|
||
"8920 Female 51.0 0 0 Yes Self-employed \n",
|
||
"65507 Male 33.0 0 0 Yes Private \n",
|
||
"43196 Female 52.0 0 0 Yes Self-employed \n",
|
||
"59745 Female 27.0 0 0 Yes Private \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"66546 Female 20.0 0 0 No Private \n",
|
||
"68798 Female 58.0 0 0 Yes Private \n",
|
||
"61409 Male 32.0 1 0 No Govt_job \n",
|
||
"69259 Female 77.0 0 0 Yes Private \n",
|
||
"17231 Female 24.0 0 0 No Private \n",
|
||
"\n",
|
||
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
|
||
"id \n",
|
||
"22159 Urban 97.06 28.5 formerly smoked 0 \n",
|
||
"8920 Rural 76.35 33.5 formerly smoked 0 \n",
|
||
"65507 Rural 55.72 38.2 never smoked 0 \n",
|
||
"43196 Urban 59.54 42.2 Unknown 0 \n",
|
||
"59745 Urban 76.74 53.9 Unknown 0 \n",
|
||
"... ... ... ... ... ... \n",
|
||
"66546 Urban 80.08 25.1 never smoked 0 \n",
|
||
"68798 Rural 59.86 28.0 formerly smoked 1 \n",
|
||
"61409 Urban 58.24 NaN formerly smoked 0 \n",
|
||
"69259 Rural 100.85 29.5 smokes 0 \n",
|
||
"17231 Urban 90.42 24.3 never smoked 0 \n",
|
||
"\n",
|
||
"[4088 rows x 11 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>stroke</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>id</th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>22159</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>8920</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>65507</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>43196</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>59745</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>66546</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>68798</th>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61409</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>69259</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>17231</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>4088 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" stroke\n",
|
||
"id \n",
|
||
"22159 0\n",
|
||
"8920 0\n",
|
||
"65507 0\n",
|
||
"43196 0\n",
|
||
"59745 0\n",
|
||
"... ...\n",
|
||
"66546 0\n",
|
||
"68798 1\n",
|
||
"61409 0\n",
|
||
"69259 0\n",
|
||
"17231 0\n",
|
||
"\n",
|
||
"[4088 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>gender</th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>hypertension</th>\n",
|
||
" <th>heart_disease</th>\n",
|
||
" <th>ever_married</th>\n",
|
||
" <th>work_type</th>\n",
|
||
" <th>Residence_type</th>\n",
|
||
" <th>avg_glucose_level</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>smoking_status</th>\n",
|
||
" <th>stroke</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>id</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>18072</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>39.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Govt_job</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>107.47</td>\n",
|
||
" <td>21.3</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>67063</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>62.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>130.56</td>\n",
|
||
" <td>36.1</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>40387</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>17.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>77.46</td>\n",
|
||
" <td>24.0</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>18032</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>62.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>90.61</td>\n",
|
||
" <td>25.8</td>\n",
|
||
" <td>smokes</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5478</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>60.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>203.04</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>smokes</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>57710</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>50.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>112.25</td>\n",
|
||
" <td>21.6</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>63043</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>27.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>61.80</td>\n",
|
||
" <td>26.8</td>\n",
|
||
" <td>formerly smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>63986</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>60.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>153.48</td>\n",
|
||
" <td>37.3</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>28461</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>15.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Never_worked</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>79.59</td>\n",
|
||
" <td>28.4</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>54975</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>7.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>64.06</td>\n",
|
||
" <td>18.9</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>1022 rows × 11 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" gender age hypertension heart_disease ever_married work_type \\\n",
|
||
"id \n",
|
||
"18072 Female 39.0 0 0 Yes Govt_job \n",
|
||
"67063 Male 62.0 0 0 Yes Self-employed \n",
|
||
"40387 Female 17.0 0 0 No Private \n",
|
||
"18032 Male 62.0 0 1 Yes Private \n",
|
||
"5478 Female 60.0 0 0 Yes Self-employed \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"57710 Female 50.0 0 0 Yes Private \n",
|
||
"63043 Female 27.0 0 0 No Private \n",
|
||
"63986 Male 60.0 0 0 Yes Private \n",
|
||
"28461 Male 15.0 0 0 No Never_worked \n",
|
||
"54975 Male 7.0 0 0 No Self-employed \n",
|
||
"\n",
|
||
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
|
||
"id \n",
|
||
"18072 Urban 107.47 21.3 Unknown 0 \n",
|
||
"67063 Urban 130.56 36.1 Unknown 0 \n",
|
||
"40387 Rural 77.46 24.0 Unknown 0 \n",
|
||
"18032 Rural 90.61 25.8 smokes 0 \n",
|
||
"5478 Urban 203.04 NaN smokes 0 \n",
|
||
"... ... ... ... ... ... \n",
|
||
"57710 Rural 112.25 21.6 Unknown 0 \n",
|
||
"63043 Urban 61.80 26.8 formerly smoked 0 \n",
|
||
"63986 Rural 153.48 37.3 never smoked 0 \n",
|
||
"28461 Rural 79.59 28.4 Unknown 0 \n",
|
||
"54975 Rural 64.06 18.9 Unknown 0 \n",
|
||
"\n",
|
||
"[1022 rows x 11 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>stroke</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>id</th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>18072</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>67063</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>40387</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>18032</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>5478</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>57710</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>63043</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>63986</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>28461</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>54975</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>1022 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" stroke\n",
|
||
"id \n",
|
||
"18072 0\n",
|
||
"67063 0\n",
|
||
"40387 0\n",
|
||
"18032 0\n",
|
||
"5478 0\n",
|
||
"... ...\n",
|
||
"57710 0\n",
|
||
"63043 0\n",
|
||
"63986 0\n",
|
||
"28461 0\n",
|
||
"54975 0\n",
|
||
"\n",
|
||
"[1022 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
"    df, stratify_colname=\"stroke\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
|
||
")\n",
|
||
"\n",
|
||
"display(\"X_train\", X_train)\n",
|
||
"display(\"y_train\", y_train)\n",
|
||
"\n",
|
||
"display(\"X_test\", X_test)\n",
|
||
"display(\"y_test\", y_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"Choose a baseline for the classification task. To do so, apply random prediction: for every sample, a randomly chosen class is used as the prediction."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Baseline Accuracy: 0.5\n",
|
||
"Baseline Precision: 0.05758157389635317\n",
|
||
"Baseline Recall: 0.6\n",
|
||
"Baseline F1 Score: 0.10507880910683012\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score\n",
|
||
"\n",
"# Get the unique classes of the target feature from the training set\n",
"unique_classes = np.unique(y_train)\n",
"\n",
"# Generate random predictions by drawing uniformly from the target's classes\n",
"random_predictions = np.random.choice(unique_classes, size=len(y_test))\n",
"\n",
"# Compute the baseline metrics\n",
|
||
"baseline_accuracy = accuracy_score(y_test, random_predictions)\n",
|
||
"baseline_precision = precision_score(y_test, random_predictions)\n",
|
||
"baseline_recall = recall_score(y_test, random_predictions)\n",
|
||
"baseline_f1 = f1_score(y_test, random_predictions)\n",
|
||
"\n",
|
||
"print('Baseline Accuracy:', baseline_accuracy)\n",
|
||
"print('Baseline Precision:', baseline_precision)\n",
|
||
"print('Baseline Recall:', baseline_recall)\n",
|
||
"print('Baseline F1 Score:', baseline_f1)"
|
||
]
|
||
},
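The shape of the baseline numbers above (accuracy and recall near 0.5, precision far lower) is what a uniform random guess is expected to produce on an imbalanced target: precision tends toward the share of positives in the data. The sketch below checks this by simulation. The prevalence of 0.05 and the sample size are assumptions chosen for illustration, not values taken from the dataset.

```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score

rng = np.random.default_rng(0)
n_samples = 200_000
prevalence = 0.05  # assumed share of positives, chosen only for illustration

# Synthetic ground truth with the assumed prevalence.
y_true = (rng.random(n_samples) < prevalence).astype(int)
# Uniform random guess between class 0 and class 1.
y_random = rng.integers(0, 2, size=n_samples)

sim_accuracy = accuracy_score(y_true, y_random)
sim_precision = precision_score(y_true, y_random)
sim_recall = recall_score(y_true, y_random)

print(f"accuracy  {sim_accuracy:.3f}  (expected about 0.5)")
print(f"precision {sim_precision:.3f}  (expected about the prevalence, {prevalence})")
print(f"recall    {sim_recall:.3f}  (expected about 0.5)")
```

With a test set of about a thousand rows, as in this notebook, the same quantities fluctuate much more from run to run, so a recall such as 0.6 is consistent with chance.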
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
"#### Model metrics\n",
"- Accuracy: the share of correct predictions among all samples. Simple, but of little use when classes are imbalanced, because it does not reflect how the model handles the rare class.\n",
"- Precision: the share of correct positive predictions among all predicted positives. Useful when false alarms are costly (for example, to avoid wrongly flagging a stroke).\n",
"- Recall: the share of positive samples the model finds among all actual positives. It shows how well the model \"catches\" the positive class, which is important for minimizing missed strokes.\n",
"- F1 Score: the harmonic mean of precision and recall. It accounts for both, which matters for imbalanced classes.\n",
"\n",
"Together these metrics cover different aspects of model behavior, from overall accuracy to the ability to find a rare class and to balance precision against recall, which allows a well-rounded evaluation."
|
||
]
|
||
},
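The definitions above can be checked by hand from confusion-matrix counts. The sketch below uses made-up counts for a hypothetical test set of 1,000 patients with 50 strokes, and the helper name `classification_metrics` is an assumption. It shows why accuracy alone is misleading on imbalanced data: a model that never predicts a stroke still reaches 95% accuracy while its recall and F1 are zero.

```python
def classification_metrics(tp, fp, fn, tn):
    """Compute accuracy, precision, recall and F1 from confusion-matrix counts."""
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total
    # Guard against division by zero when nothing is predicted positive
    # or when there are no positive samples.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical model A: always predicts "no stroke".
always_negative = classification_metrics(tp=0, fp=0, fn=50, tn=950)
# Hypothetical model B: flags 200 patients and catches 25 of the 50 strokes.
cautious = classification_metrics(tp=25, fp=175, fn=25, tn=775)

print(always_negative)
print(cautious)
```

Model B has lower accuracy (0.80) than model A (0.95), yet it is the only one of the two that finds any strokes: recall 0.50, precision 0.125, F1 0.20.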
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Building the classification pipeline"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.compose import ColumnTransformer\n",
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"from sklearn.impute import SimpleImputer\n",
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"from sklearn.preprocessing import OneHotEncoder\n",
|
||
"\n",
|
||
"columns_to_drop = [\"work_type\", \"stroke\"]\n",
|
||
"columns_not_to_modify = [\"hypertension\", \"heart_disease\"]\n",
|
||
"\n",
|
||
"num_columns = [\n",
|
||
" column\n",
|
||
" for column in df.columns\n",
|
||
" if column not in columns_to_drop\n",
|
||
" and column not in columns_not_to_modify\n",
|
||
" and df[column].dtype != \"object\"\n",
|
||
"]\n",
|
||
"\n",
|
||
"cat_columns = [\n",
|
||
" column\n",
|
||
" for column in df.columns\n",
|
||
" if column not in columns_to_drop\n",
|
||
" and column not in columns_not_to_modify\n",
|
||
" and df[column].dtype == \"object\"\n",
|
||
"]\n",
|
||
"\n",
|
||
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
||
"num_scaler = StandardScaler()\n",
|
||
"preprocessing_num = Pipeline(\n",
|
||
" [\n",
|
||
" (\"imputer\", num_imputer),\n",
|
||
" (\"scaler\", num_scaler),\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
||
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
||
"preprocessing_cat = Pipeline(\n",
|
||
" [\n",
|
||
" (\"imputer\", cat_imputer),\n",
|
||
" (\"encoder\", cat_encoder),\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"features_preprocessing = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=False,\n",
|
||
" transformers=[\n",
|
||
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
||
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
||
" ],\n",
|
||
" remainder=\"passthrough\"\n",
|
||
")\n",
|
||
"\n",
|
||
"drop_columns = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=False,\n",
|
||
" transformers=[\n",
|
||
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
||
" ],\n",
|
||
" remainder=\"passthrough\",\n",
|
||
")\n",
|
||
"\n",
|
||
"pipeline_end = Pipeline(\n",
|
||
" [\n",
|
||
" (\"features_preprocessing\", features_preprocessing),\n",
|
||
" (\"drop_columns\", drop_columns),\n",
|
||
" ]\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Now let's check that the pipeline works:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>avg_glucose_level</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>gender_Male</th>\n",
|
||
" <th>gender_Other</th>\n",
|
||
" <th>ever_married_Yes</th>\n",
|
||
" <th>Residence_type_Urban</th>\n",
|
||
" <th>smoking_status_formerly smoked</th>\n",
|
||
" <th>smoking_status_never smoked</th>\n",
|
||
" <th>smoking_status_smokes</th>\n",
|
||
" <th>hypertension</th>\n",
|
||
" <th>heart_disease</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>id</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>22159</th>\n",
|
||
" <td>0.472344</td>\n",
|
||
" <td>-0.194427</td>\n",
|
||
" <td>-0.059214</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>8920</th>\n",
|
||
" <td>0.339807</td>\n",
|
||
" <td>-0.653763</td>\n",
|
||
" <td>0.587887</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>65507</th>\n",
|
||
" <td>-0.455418</td>\n",
|
||
" <td>-1.111325</td>\n",
|
||
" <td>1.196162</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>43196</th>\n",
|
||
" <td>0.383986</td>\n",
|
||
" <td>-1.026600</td>\n",
|
||
" <td>1.713843</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>59745</th>\n",
|
||
" <td>-0.720492</td>\n",
|
||
" <td>-0.645113</td>\n",
|
||
" <td>3.228060</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>66546</th>\n",
|
||
" <td>-1.029746</td>\n",
|
||
" <td>-0.571034</td>\n",
|
||
" <td>-0.499243</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>68798</th>\n",
|
||
" <td>0.649060</td>\n",
|
||
" <td>-1.019502</td>\n",
|
||
" <td>-0.123924</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61409</th>\n",
|
||
" <td>-0.499597</td>\n",
|
||
" <td>-1.055433</td>\n",
|
||
" <td>-0.098040</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>69259</th>\n",
|
||
" <td>1.488464</td>\n",
|
||
" <td>-0.110367</td>\n",
|
||
" <td>0.070206</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>17231</th>\n",
|
||
" <td>-0.853030</td>\n",
|
||
" <td>-0.341699</td>\n",
|
||
" <td>-0.602779</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>4088 rows × 12 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age avg_glucose_level bmi gender_Male gender_Other \\\n",
|
||
"id \n",
|
||
"22159 0.472344 -0.194427 -0.059214 0.0 0.0 \n",
|
||
"8920 0.339807 -0.653763 0.587887 0.0 0.0 \n",
|
||
"65507 -0.455418 -1.111325 1.196162 1.0 0.0 \n",
|
||
"43196 0.383986 -1.026600 1.713843 0.0 0.0 \n",
|
||
"59745 -0.720492 -0.645113 3.228060 0.0 0.0 \n",
|
||
"... ... ... ... ... ... \n",
|
||
"66546 -1.029746 -0.571034 -0.499243 0.0 0.0 \n",
|
||
"68798 0.649060 -1.019502 -0.123924 0.0 0.0 \n",
|
||
"61409 -0.499597 -1.055433 -0.098040 1.0 0.0 \n",
|
||
"69259 1.488464 -0.110367 0.070206 0.0 0.0 \n",
|
||
"17231 -0.853030 -0.341699 -0.602779 0.0 0.0 \n",
|
||
"\n",
|
||
" ever_married_Yes Residence_type_Urban smoking_status_formerly smoked \\\n",
|
||
"id \n",
|
||
"22159 0.0 1.0 1.0 \n",
|
||
"8920 1.0 0.0 1.0 \n",
|
||
"65507 1.0 0.0 0.0 \n",
|
||
"43196 1.0 1.0 0.0 \n",
|
||
"59745 1.0 1.0 0.0 \n",
|
||
"... ... ... ... \n",
|
||
"66546 0.0 1.0 0.0 \n",
|
||
"68798 1.0 0.0 1.0 \n",
|
||
"61409 0.0 1.0 1.0 \n",
|
||
"69259 1.0 0.0 0.0 \n",
|
||
"17231 0.0 1.0 0.0 \n",
|
||
"\n",
|
||
" smoking_status_never smoked smoking_status_smokes hypertension \\\n",
|
||
"id \n",
|
||
"22159 0.0 0.0 1 \n",
|
||
"8920 0.0 0.0 0 \n",
|
||
"65507 1.0 0.0 0 \n",
|
||
"43196 0.0 0.0 0 \n",
|
||
"59745 0.0 0.0 0 \n",
|
||
"... ... ... ... \n",
|
||
"66546 1.0 0.0 0 \n",
|
||
"68798 0.0 0.0 0 \n",
|
||
"61409 0.0 0.0 1 \n",
|
||
"69259 0.0 1.0 0 \n",
|
||
"17231 1.0 0.0 0 \n",
|
||
"\n",
|
||
" heart_disease \n",
|
||
"id \n",
|
||
"22159 0 \n",
|
||
"8920 0 \n",
|
||
"65507 0 \n",
|
||
"43196 0 \n",
|
||
"59745 0 \n",
|
||
"... ... \n",
|
||
"66546 0 \n",
|
||
"68798 0 \n",
|
||
"61409 0 \n",
|
||
"69259 0 \n",
|
||
"17231 0 \n",
|
||
"\n",
|
||
"[4088 rows x 12 columns]"
|
||
]
|
||
},
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
||
"preprocessed_df = pd.DataFrame(\n",
|
||
" preprocessing_result,\n",
|
||
" columns=pipeline_end.get_feature_names_out(),\n",
|
||
")\n",
|
||
"\n",
|
||
"preprocessed_df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Let's tune the hyperparameters of each selected model with a grid search and collect the tuned models.\n",
|
||
"\n",
|
||
"knn -- k-nearest neighbors\n",
|
||
"\n",
|
||
"random_forest -- random forest (an ensemble of decision trees)\n",
|
||
"\n",
|
||
"mlp -- multilayer perceptron (a neural network)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"d:\\code\\mai\\labs\\AIM-PIbd-31-Bakalskaya-E-D\\lab_4\\venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
||
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Лучшие параметры для knn: {'n_neighbors': 1, 'weights': 'uniform'}\n",
|
||
"Лучшие параметры для random_forest: {'class_weight': 'balanced_subsample', 'criterion': 'entropy', 'max_depth': 7, 'max_features': 'sqrt', 'n_estimators': 50, 'random_state': 9}\n",
|
||
"Лучшие параметры для mlp: {'alpha': np.float64(0.1), 'early_stopping': True, 'hidden_layer_sizes': np.int64(14), 'max_iter': 1000, 'random_state': 9, 'solver': 'adam'}\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.model_selection import GridSearchCV\n",
|
||
"from sklearn import neighbors, ensemble, neural_network\n",
|
||
"\n",
|
||
"# Hyperparameter grids for each model\n",
|
||
"param_grids = {\n",
|
||
" \"knn\": {\n",
|
||
" \"n_neighbors\": list(range(1, 31)),\n",
|
||
" \"weights\": ['uniform', 'distance']\n",
|
||
" },\n",
|
||
" \"random_forest\": {\n",
|
||
" \"n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
|
||
" \"max_features\": [\"sqrt\", \"log2\", 2],\n",
|
||
" \"max_depth\": [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
|
||
" \"criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
|
||
" \"random_state\": [random_state],\n",
|
||
" \"class_weight\": [\"balanced\", \"balanced_subsample\"]\n",
|
||
" },\n",
|
||
" \"mlp\": {\n",
|
||
" \"solver\": ['adam'], \n",
|
||
" \"max_iter\": list(range(1000, 2001, 100)),\n",
|
||
" \"alpha\": 10.0 ** -np.arange(1, 10), \n",
|
||
" \"hidden_layer_sizes\":np.arange(10, 15), \n",
|
||
" \"early_stopping\": [True, False],\n",
|
||
" \"random_state\": [random_state]\n",
|
||
" }\n",
|
||
"}\n",
|
||
"\n",
|
||
"# Create the model instances\n",
|
||
"models = {\n",
|
||
" \"knn\": neighbors.KNeighborsClassifier(),\n",
|
||
" \"random_forest\": ensemble.RandomForestClassifier(),\n",
|
||
" \"mlp\": neural_network.MLPClassifier()\n",
|
||
"}\n",
|
||
"\n",
|
||
"# Dictionary that stores each model configured with its best parameters\n",
|
||
"class_models = {}\n",
|
||
"\n",
|
||
"# Run the grid search for each model\n",
|
||
"for model_name, model in models.items():\n",
|
||
" # Create a GridSearchCV for the current model\n",
|
||
" gs_optimizer = GridSearchCV(estimator=model, param_grid=param_grids[model_name], scoring=\"f1\", n_jobs=-1)\n",
|
||
" \n",
|
||
" # Fit the GridSearchCV\n",
|
||
" gs_optimizer.fit(preprocessed_df, y_train.values.ravel())\n",
|
||
" \n",
|
||
" # Get the best parameters\n",
|
||
" best_params = gs_optimizer.best_params_\n",
|
||
" print(f\"Лучшие параметры для {model_name}: {best_params}\")\n",
|
||
" \n",
|
||
" class_models[model_name] = {\n",
|
||
" \"model\": model.set_params(**best_params) # Configure the model with the best parameters\n",
|
||
" }"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Note: the grid search above took 12 minutes to run."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Now let's fit each tuned model and look at how well it learned."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: knn\n",
|
||
"Model: random_forest\n",
|
||
"Model: mlp\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.metrics import confusion_matrix\n",
|
||
"\n",
|
||
"for model_name in class_models.keys():\n",
|
||
" print(f\"Model: {model_name}\")\n",
|
||
" model = class_models[model_name][\"model\"]\n",
|
||
"\n",
|
||
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
||
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
||
"\n",
|
||
" y_train_predict = model_pipeline.predict(X_train)\n",
|
||
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
||
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
||
"\n",
|
||
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
||
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
||
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
||
"\n",
|
||
" class_models[model_name][\"Precision_train\"] = precision_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Precision_test\"] = precision_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Recall_train\"] = recall_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Recall_test\"] = recall_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Accuracy_train\"] = accuracy_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Accuracy_test\"] = accuracy_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" ) \n",
|
||
" class_models[model_name][\"F1_train\"] = f1_score(y_train, y_train_predict)\n",
|
||
" class_models[model_name][\"F1_test\"] = f1_score(y_test, y_test_predict)\n",
|
||
" class_models[model_name][\"Confusion_matrix\"] = confusion_matrix(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )"
|
||
]
|
||
},
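The loop above turns predicted probabilities into labels with a fixed 0.5 cutoff. The sketch below uses made-up labels and probabilities, and the helper `precision_recall_at` is hypothetical. It shows how moving that cutoff trades precision for recall, which is one way to look at a model whose precision or recall is unsatisfactory.

```python
def precision_recall_at(y_true, probs, threshold):
    """Precision and recall when a sample is labeled positive if prob > threshold."""
    preds = [1 if p > threshold else 0 for p in probs]
    tp = sum(1 for t, p in zip(y_true, preds) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, preds) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, preds) if t == 1 and p == 0)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# Assumed toy data: 10 samples, 3 of them positive.
y_true = [0, 0, 0, 0, 0, 0, 1, 0, 1, 1]
probs = [0.05, 0.10, 0.20, 0.30, 0.35, 0.45, 0.40, 0.60, 0.55, 0.90]

results = {t: precision_recall_at(y_true, probs, t) for t in (0.3, 0.5, 0.7)}
for threshold, (precision, recall) in results.items():
    print(f"threshold={threshold}: precision={precision:.3f}, recall={recall:.3f}")
```

On this toy data a lower cutoff (0.3) finds every positive at the cost of more false alarms, while a higher cutoff (0.7) removes the false alarms but misses two of the three positives.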
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Confusion matrices:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1200x1000 with 6 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import numpy as np\n",
|
||
"\n",
|
||
"# Color scheme: light blue for correct predictions, pale red for errors\n",
|
||
"def custom_color_map(c_matrix):\n",
|
||
"    colors = np.empty_like(c_matrix, dtype=str)\n",
|
||
"    for i in range(c_matrix.shape[0]):\n",
|
||
"        for j in range(c_matrix.shape[1]):\n",
|
||
"            if i == j:\n",
|
||
"                colors[i, j] = \"#add8e6\"  # Light blue\n",
|
||
"            else:\n",
|
||
"                colors[i, j] = \"#f08080\"  # Pale red\n",
|
||
"    return colors\n",
|
||
"\n",
|
||
"fig, ax = plt.subplots(2, 2, figsize=(12, 10))\n",
|
||
"\n",
|
||
"for index, (key, model_info) in enumerate(class_models.items()):\n",
|
||
" c_matrix = model_info[\"Confusion_matrix\"]\n",
|
||
"\n",
|
||
" # Build the display with the custom colors\n",
|
||
" disp = ConfusionMatrixDisplay(\n",
|
||
" confusion_matrix=c_matrix, display_labels=[\"Not stroke\", \"Stroke\"]\n",
|
||
" )\n",
|
||
" disp.plot(ax=ax.flat[index], cmap=custom_color_map(c_matrix), colorbar=False)\n",
|
||
"\n",
|
||
" disp.ax_.set_title(key)\n",
|
||
"\n",
|
||
"if len(class_models) < len(ax.flat):\n",
|
||
"    for i in range(len(class_models), len(ax.flat)):\n",
|
||
"        fig.delaxes(ax.flat[i])\n",
|
||
"\n",
|
||
"plt.subplots_adjust(top=0.9, bottom=0.1, hspace=0.4, wspace=0.3)\n",
|
||
"plt.show()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Precision, Recall, Accuracy, F1:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_0e668_row0_col0 {\n",
|
||
" background-color: #1f988b;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_0e668_row0_col1, #T_0e668_row1_col0, #T_0e668_row1_col2, #T_0e668_row2_col3 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_0e668_row0_col2, #T_0e668_row0_col3, #T_0e668_row1_col1, #T_0e668_row2_col0 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_0e668_row0_col4 {\n",
|
||
" background-color: #b7318a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_0e668_row0_col5, #T_0e668_row1_col4, #T_0e668_row1_col6, #T_0e668_row2_col7 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_0e668_row0_col6, #T_0e668_row0_col7, #T_0e668_row2_col4, #T_0e668_row2_col5 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_0e668_row1_col3, #T_0e668_row2_col1 {\n",
|
||
" background-color: #1f968b;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_0e668_row1_col5 {\n",
|
||
" background-color: #be3885;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_0e668_row1_col7 {\n",
|
||
" background-color: #9c179e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_0e668_row2_col2 {\n",
|
||
" background-color: #86d549;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_0e668_row2_col6 {\n",
|
||
" background-color: #8808a6;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_0e668\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_0e668_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
||
" <th id=\"T_0e668_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
||
" <th id=\"T_0e668_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
||
" <th id=\"T_0e668_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
||
" <th id=\"T_0e668_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
||
" <th id=\"T_0e668_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_0e668_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
||
" <th id=\"T_0e668_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_0e668_level0_row0\" class=\"row_heading level0 row0\" >mlp</th>\n",
|
||
" <td id=\"T_0e668_row0_col0\" class=\"data row0 col0\" >0.400000</td>\n",
|
||
" <td id=\"T_0e668_row0_col1\" class=\"data row0 col1\" >0.200000</td>\n",
|
||
" <td id=\"T_0e668_row0_col2\" class=\"data row0 col2\" >0.020101</td>\n",
|
||
" <td id=\"T_0e668_row0_col3\" class=\"data row0 col3\" >0.020000</td>\n",
|
||
" <td id=\"T_0e668_row0_col4\" class=\"data row0 col4\" >0.950832</td>\n",
|
||
" <td id=\"T_0e668_row0_col5\" class=\"data row0 col5\" >0.948141</td>\n",
|
||
" <td id=\"T_0e668_row0_col6\" class=\"data row0 col6\" >0.038278</td>\n",
|
||
" <td id=\"T_0e668_row0_col7\" class=\"data row0 col7\" >0.036364</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_0e668_level0_row1\" class=\"row_heading level0 row1\" >knn</th>\n",
|
||
" <td id=\"T_0e668_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_0e668_row1_col1\" class=\"data row1 col1\" >0.117647</td>\n",
|
||
" <td id=\"T_0e668_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_0e668_row1_col3\" class=\"data row1 col3\" >0.120000</td>\n",
|
||
" <td id=\"T_0e668_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_0e668_row1_col5\" class=\"data row1 col5\" >0.912916</td>\n",
|
||
" <td id=\"T_0e668_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_0e668_row1_col7\" class=\"data row1 col7\" >0.118812</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_0e668_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
||
" <td id=\"T_0e668_row2_col0\" class=\"data row2 col0\" >0.228869</td>\n",
|
||
" <td id=\"T_0e668_row2_col1\" class=\"data row2 col1\" >0.135135</td>\n",
|
||
" <td id=\"T_0e668_row2_col2\" class=\"data row2 col2\" >0.884422</td>\n",
|
||
" <td id=\"T_0e668_row2_col3\" class=\"data row2 col3\" >0.500000</td>\n",
|
||
" <td id=\"T_0e668_row2_col4\" class=\"data row2 col4\" >0.849315</td>\n",
|
||
" <td id=\"T_0e668_row2_col5\" class=\"data row2 col5\" >0.818982</td>\n",
|
||
" <td id=\"T_0e668_row2_col6\" class=\"data row2 col6\" >0.363636</td>\n",
|
||
" <td id=\"T_0e668_row2_col7\" class=\"data row2 col7\" >0.212766</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1bdc4916000>"
|
||
]
|
||
},
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
||
" [\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" \"Accuracy_train\",\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_train\",\n",
|
||
" \"F1_test\",\n",
|
||
" ]\n",
|
||
"]\n",
|
||
"class_metrics.sort_values(\n",
|
||
" by=\"Accuracy_test\", ascending=False\n",
|
||
").style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Model analysis\n",
|
||
"##### MLP (multilayer perceptron)\n",
|
||
"- Accuracy: 95% on both the training and test sets.\n",
|
||
"- Precision and recall: precision is low and recall is close to zero (0.40 and 0.02 on training, 0.20 and 0.02 on test).\n",
|
||
"- F1: practically zero (0.038 on training, 0.036 on test).\n",
|
||
"- Conclusion: the model predicts the majority class well but almost never detects positive cases.\n",
|
||
"\n",
|
||
"##### KNN (k-nearest neighbors)\n",
|
||
"- Training: perfect metrics (1.0), which is expected with n_neighbors=1 because every training sample is its own nearest neighbor.\n",
|
||
"- Test: a sharp drop (precision 0.118, recall 0.12, accuracy 91%).\n",
|
||
"- Conclusion: overfitting. The model reproduces the training set but generalizes poorly to new data.\n",
|
||
"\n",
|
||
"##### Random Forest\n",
|
||
"- Accuracy: 85% (training) and 82% (test).\n",
|
||
"- Precision and recall: low precision but moderate recall on the test set (0.135 and 0.50).\n",
|
||
"- F1: the best of the three models (0.213 on test).\n",
|
||
"- Conclusion: the metrics are better balanced than for the other models, but precision on the positive class still leaves much to be desired.\n",
|
||
"\n",
|
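||
"##### Comparison with the baseline\n",
|
||
"- Baseline (random guessing): accuracy 50%, precision 0.058, recall 0.60, F1 0.105. The random predictions are not seeded, so these values vary slightly between runs.\n",
|
||
"- Accuracy: all three models are far above the baseline.\n",
|
||
"- Recall: no model beats the baseline (0.60). Random Forest comes closest (0.50), while KNN (0.12) and MLP (0.02) fall far behind.\n",
|
||
"- F1: Random Forest leads (0.213), KNN is only slightly above the baseline (0.119), and MLP is below it (0.036). All of them are still far from the desired level.\n",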
|
||
"\n",
|
||
"##### Summary\n",
|
||
"- MLP: strongly biased toward the majority class; it ignores positive cases.\n",
|
||
"- KNN: high variance, heavily overfitted.\n",
|
||
"- Random Forest: the most balanced option, but its precision needs to improve.\n",
|
||
"- Bottom line: Random Forest is the best of the three candidates, but it still needs work."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Regression\n",
|
||
"\n",
|
||
"Split the dataset into training and test sets (80/20). The target feature is avg_glucose_level."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>gender</th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>hypertension</th>\n",
|
||
" <th>heart_disease</th>\n",
|
||
" <th>ever_married</th>\n",
|
||
" <th>work_type</th>\n",
|
||
" <th>Residence_type</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>smoking_status</th>\n",
|
||
" <th>stroke</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>id</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>13276</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>38.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>22.6</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>21346</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>12.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>children</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>17.8</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>59178</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>7.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>children</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>22.3</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1679</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>35.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>formerly smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1534</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>61.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>26.1</td>\n",
|
||
" <td>smokes</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>30463</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>29.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>29.4</td>\n",
|
||
" <td>formerly smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>41935</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>34.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>33.9</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>68483</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>60.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>41.2</td>\n",
|
||
" <td>formerly smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>38617</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>28.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>29.9</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>46527</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>53.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Govt_job</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>41.9</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>4088 rows × 10 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" gender age hypertension heart_disease ever_married work_type \\\n",
|
||
"id \n",
|
||
"13276 Female 38.0 0 0 Yes Private \n",
|
||
"21346 Female 12.0 0 0 No children \n",
|
||
"59178 Female 7.0 0 0 No children \n",
|
||
"1679 Male 35.0 0 0 Yes Private \n",
|
||
"1534 Female 61.0 0 0 Yes Private \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"30463 Male 29.0 0 0 No Private \n",
|
||
"41935 Male 34.0 0 0 No Private \n",
|
||
"68483 Female 60.0 0 0 Yes Private \n",
|
||
"38617 Male 28.0 0 0 Yes Self-employed \n",
|
||
"46527 Male 53.0 1 1 Yes Govt_job \n",
|
||
"\n",
|
||
" Residence_type bmi smoking_status stroke \n",
|
||
"id \n",
|
||
"13276 Urban 22.6 Unknown 0 \n",
|
||
"21346 Rural 17.8 Unknown 0 \n",
|
||
"59178 Urban 22.3 Unknown 0 \n",
|
||
"1679 Rural NaN formerly smoked 0 \n",
|
||
"1534 Rural 26.1 smokes 0 \n",
|
||
"... ... ... ... ... \n",
|
||
"30463 Urban 29.4 formerly smoked 0 \n",
|
||
"41935 Rural 33.9 never smoked 0 \n",
|
||
"68483 Urban 41.2 formerly smoked 0 \n",
|
||
"38617 Urban 29.9 never smoked 0 \n",
|
||
"46527 Rural 41.9 never smoked 0 \n",
|
||
"\n",
|
||
"[4088 rows x 10 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"id\n",
|
||
"13276 71.06\n",
|
||
"21346 70.13\n",
|
||
"59178 86.75\n",
|
||
"1679 77.48\n",
|
||
"1534 99.35\n",
|
||
" ... \n",
|
||
"30463 82.93\n",
|
||
"41935 125.29\n",
|
||
"68483 65.38\n",
|
||
"38617 73.98\n",
|
||
"46527 109.51\n",
|
||
"Name: avg_glucose_level, Length: 4088, dtype: float64"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>gender</th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>hypertension</th>\n",
|
||
" <th>heart_disease</th>\n",
|
||
" <th>ever_married</th>\n",
|
||
" <th>work_type</th>\n",
|
||
" <th>Residence_type</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>smoking_status</th>\n",
|
||
" <th>stroke</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>id</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>8385</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>37.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>35.9</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>937</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>7.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>children</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>NaN</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3494</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>80.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>26.7</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>23850</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>66.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>33.1</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>31156</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>49.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>29.8</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>71010</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>80.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>22.8</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>39518</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>20.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>No</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Rural</td>\n",
|
||
" <td>20.7</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>7780</th>\n",
|
||
" <td>Male</td>\n",
|
||
" <td>51.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Self-employed</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>30.7</td>\n",
|
||
" <td>never smoked</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>56137</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>62.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Private</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>36.3</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>33175</th>\n",
|
||
" <td>Female</td>\n",
|
||
" <td>57.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>Yes</td>\n",
|
||
" <td>Govt_job</td>\n",
|
||
" <td>Urban</td>\n",
|
||
" <td>28.5</td>\n",
|
||
" <td>Unknown</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>1022 rows × 10 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" gender age hypertension heart_disease ever_married work_type \\\n",
|
||
"id \n",
|
||
"8385 Male 37.0 0 0 Yes Private \n",
|
||
"937 Male 7.0 0 0 No children \n",
|
||
"3494 Female 80.0 0 0 Yes Private \n",
|
||
"23850 Male 66.0 0 0 Yes Private \n",
|
||
"31156 Female 49.0 0 0 Yes Private \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"71010 Female 80.0 0 0 No Self-employed \n",
|
||
"39518 Female 20.0 0 0 No Private \n",
|
||
"7780 Male 51.0 0 0 Yes Self-employed \n",
|
||
"56137 Female 62.0 0 0 Yes Private \n",
|
||
"33175 Female 57.0 0 0 Yes Govt_job \n",
|
||
"\n",
|
||
" Residence_type bmi smoking_status stroke \n",
|
||
"id \n",
|
||
"8385 Urban 35.9 Unknown 0 \n",
|
||
"937 Urban NaN Unknown 0 \n",
|
||
"3494 Rural 26.7 Unknown 0 \n",
|
||
"23850 Urban 33.1 never smoked 0 \n",
|
||
"31156 Urban 29.8 never smoked 0 \n",
|
||
"... ... ... ... ... \n",
|
||
"71010 Urban 22.8 never smoked 0 \n",
|
||
"39518 Rural 20.7 never smoked 0 \n",
|
||
"7780 Urban 30.7 never smoked 0 \n",
|
||
"56137 Urban 36.3 Unknown 0 \n",
|
||
"33175 Urban 28.5 Unknown 1 \n",
|
||
"\n",
|
||
"[1022 rows x 10 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"id\n",
|
||
"8385 90.78\n",
|
||
"937 87.94\n",
|
||
"3494 102.90\n",
|
||
"23850 103.01\n",
|
||
"31156 105.99\n",
|
||
" ... \n",
|
||
"71010 57.57\n",
|
||
"39518 78.94\n",
|
||
"7780 75.73\n",
|
||
"56137 88.32\n",
|
||
"33175 110.52\n",
|
||
"Name: avg_glucose_level, Length: 1022, dtype: float64"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"features = ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'bmi', 'smoking_status', 'stroke']\n",
|
||
"target = 'avg_glucose_level'\n",
|
||
"\n",
|
||
"X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=random_state)\n",
|
||
"\n",
|
||
"display(\"X_train\", X_train)\n",
|
||
"display(\"y_train\", y_train)\n",
|
||
"\n",
|
||
"display(\"X_test\", X_test)\n",
|
||
"display(\"y_test\", y_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Выберем ориентир для задачи регрессии. Для этого применим алгоритм правила нуля, т.е. в каждом случае в качестве предсказания выберем среднее значение из области значений целевого признака."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Baseline RMSE: 44.12711275645952\n",
|
||
"Baseline RMAE: 5.662154850745081\n",
|
||
"Baseline R2: -0.0010729515309222393\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import math\n",
|
||
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
|
||
"\n",
|
||
"# Базовое предсказание: среднее значение по y_train\n",
|
||
"baseline_predictions = [y_train.mean()] * len(y_test)\n",
|
||
"\n",
|
||
"# Вычисление метрик качества для ориентира\n",
|
||
"baseline_rmse = math.sqrt(\n",
|
||
" mean_squared_error(y_test, baseline_predictions)\n",
|
||
" )\n",
|
||
"baseline_rmae = math.sqrt(\n",
|
||
" mean_absolute_error(y_test, baseline_predictions)\n",
|
||
" )\n",
|
||
"baseline_r2 = r2_score(y_test, baseline_predictions)\n",
|
||
"\n",
|
||
"print('Baseline RMSE:', baseline_rmse)\n",
|
||
"print('Baseline RMAE:', baseline_rmae)\n",
|
||
"print('Baseline R2:', baseline_r2)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Метрики:\n",
|
||
"##### RMSE: корень из MSE, измеряет среднеквадратическую ошибку. \n",
|
||
"Удобен, так как результат в тех же единицах, что и данные. Штрафует за большие отклонения. Хорош для задач, где важна интерпретируемость.\n",
|
||
"##### RMAE: корень из MAE, измеряет среднюю абсолютную ошибку. \n",
|
||
"Менее чувствителен к выбросам, что полезно для данных с редкими сильными отклонениями.\n",
|
||
"##### R²: коэффициент детерминации, показывает, насколько модель объясняет изменчивость данных. Значение ближе к 1 — модель хорошо описывает данные\n",
|
||
"Используется для сравнения моделей на одинаковых данных.\n",
|
||
"Эти метрики помогают оценить, насколько точна модель по сравнению с простым усреднением."
|
||
]
|
||
},
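The metric definitions above can be checked by hand. The following is a minimal sketch, assuming invented toy values rather than the stroke dataset; the helper names `zero_rule_predict` and `regression_metrics` are hypothetical and are not part of the notebook or of scikit-learn. It also illustrates why a zero-rule baseline has an R2 of about zero: predicting a constant equal to the mean of the evaluated targets gives exactly 0, and any mismatch between the training and test means pushes it slightly negative.

```python
import math


def zero_rule_predict(y_train, n):
    """Zero-rule baseline: repeat the training mean n times."""
    mean = sum(y_train) / len(y_train)
    return [mean] * n


def regression_metrics(y_true, y_pred):
    """Compute RMSE, MAE, RMAE (sqrt of MAE) and R2 from first principles."""
    n = len(y_true)
    errors = [t - p for t, p in zip(y_true, y_pred)]
    mse = sum(e * e for e in errors) / n
    mae = sum(abs(e) for e in errors) / n
    mean_true = sum(y_true) / n
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # total variation
    ss_res = sum(e * e for e in errors)                 # unexplained variation
    return {
        "rmse": math.sqrt(mse),
        "mae": mae,
        "rmae": math.sqrt(mae),
        "r2": 1.0 - ss_res / ss_tot,
    }


# Toy glucose-like values, invented purely for illustration
toy_train = [80.0, 90.0, 100.0, 110.0, 120.0]
toy_test = [70.0, 100.0, 130.0]

toy_pred = zero_rule_predict(toy_train, len(toy_test))
toy_metrics = regression_metrics(toy_test, toy_pred)
print(toy_metrics)
```

In this toy case the training mean happens to equal the test mean, so R2 comes out as exactly zero.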
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Сформируем конвейер для регрессии"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"columns_to_drop = []\n",
|
||
"columns_not_to_modify = [\"hypertension\", \"heart_disease\", \"stroke\", \"avg_glucose_level\"]\n",
|
||
"\n",
|
||
"num_columns = [\n",
|
||
" column\n",
|
||
" for column in df.columns\n",
|
||
" if column not in columns_to_drop\n",
|
||
" and column not in columns_not_to_modify\n",
|
||
" and df[column].dtype != \"object\"\n",
|
||
"]\n",
|
||
"\n",
|
||
"cat_columns = [\n",
|
||
" column\n",
|
||
" for column in df.columns\n",
|
||
" if column not in columns_to_drop\n",
|
||
" and column not in columns_not_to_modify\n",
|
||
" and df[column].dtype == \"object\"\n",
|
||
"]\n",
|
||
"\n",
|
||
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
||
"num_scaler = StandardScaler()\n",
|
||
"preprocessing_num = Pipeline(\n",
|
||
" [\n",
|
||
" (\"imputer\", num_imputer),\n",
|
||
" (\"scaler\", num_scaler),\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
||
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
||
"preprocessing_cat = Pipeline(\n",
|
||
" [\n",
|
||
" (\"imputer\", cat_imputer),\n",
|
||
" (\"encoder\", cat_encoder),\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"features_preprocessing = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=False,\n",
|
||
" transformers=[\n",
|
||
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
||
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
||
" ],\n",
|
||
" remainder=\"passthrough\"\n",
|
||
")\n",
|
||
"\n",
|
||
"drop_columns = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=False,\n",
|
||
" transformers=[\n",
|
||
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
||
" ],\n",
|
||
" remainder=\"passthrough\",\n",
|
||
")\n",
|
||
"\n",
|
||
"pipeline_end_reg = Pipeline(\n",
|
||
" [\n",
|
||
" (\"features_preprocessing\", features_preprocessing),\n",
|
||
" (\"drop_columns\", drop_columns),\n",
|
||
" ]\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Теперь проверим работу конвейера:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>gender_Male</th>\n",
|
||
" <th>gender_Other</th>\n",
|
||
" <th>ever_married_Yes</th>\n",
|
||
" <th>work_type_Never_worked</th>\n",
|
||
" <th>work_type_Private</th>\n",
|
||
" <th>work_type_Self-employed</th>\n",
|
||
" <th>work_type_children</th>\n",
|
||
" <th>Residence_type_Urban</th>\n",
|
||
" <th>smoking_status_formerly smoked</th>\n",
|
||
" <th>smoking_status_never smoked</th>\n",
|
||
" <th>smoking_status_smokes</th>\n",
|
||
" <th>hypertension</th>\n",
|
||
" <th>heart_disease</th>\n",
|
||
" <th>stroke</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>id</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>13276</th>\n",
|
||
" <td>-0.236211</td>\n",
|
||
" <td>-0.826056</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>21346</th>\n",
|
||
" <td>-1.386874</td>\n",
|
||
" <td>-1.455413</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>59178</th>\n",
|
||
" <td>-1.608155</td>\n",
|
||
" <td>-0.865391</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1679</th>\n",
|
||
" <td>-0.368980</td>\n",
|
||
" <td>-0.104918</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1534</th>\n",
|
||
" <td>0.781682</td>\n",
|
||
" <td>-0.367150</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>30463</th>\n",
|
||
" <td>-0.634518</td>\n",
|
||
" <td>0.065532</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>41935</th>\n",
|
||
" <td>-0.413236</td>\n",
|
||
" <td>0.655554</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>68483</th>\n",
|
||
" <td>0.737426</td>\n",
|
||
" <td>1.612701</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>38617</th>\n",
|
||
" <td>-0.678774</td>\n",
|
||
" <td>0.131090</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>46527</th>\n",
|
||
" <td>0.427632</td>\n",
|
||
" <td>1.704482</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>4088 rows × 16 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age bmi gender_Male gender_Other ever_married_Yes \\\n",
|
||
"id \n",
|
||
"13276 -0.236211 -0.826056 0.0 0.0 1.0 \n",
|
||
"21346 -1.386874 -1.455413 0.0 0.0 0.0 \n",
|
||
"59178 -1.608155 -0.865391 0.0 0.0 0.0 \n",
|
||
"1679 -0.368980 -0.104918 1.0 0.0 1.0 \n",
|
||
"1534 0.781682 -0.367150 0.0 0.0 1.0 \n",
|
||
"... ... ... ... ... ... \n",
|
||
"30463 -0.634518 0.065532 1.0 0.0 0.0 \n",
|
||
"41935 -0.413236 0.655554 1.0 0.0 0.0 \n",
|
||
"68483 0.737426 1.612701 0.0 0.0 1.0 \n",
|
||
"38617 -0.678774 0.131090 1.0 0.0 1.0 \n",
|
||
"46527 0.427632 1.704482 1.0 0.0 1.0 \n",
|
||
"\n",
|
||
" work_type_Never_worked work_type_Private work_type_Self-employed \\\n",
|
||
"id \n",
|
||
"13276 0.0 1.0 0.0 \n",
|
||
"21346 0.0 0.0 0.0 \n",
|
||
"59178 0.0 0.0 0.0 \n",
|
||
"1679 0.0 1.0 0.0 \n",
|
||
"1534 0.0 1.0 0.0 \n",
|
||
"... ... ... ... \n",
|
||
"30463 0.0 1.0 0.0 \n",
|
||
"41935 0.0 1.0 0.0 \n",
|
||
"68483 0.0 1.0 0.0 \n",
|
||
"38617 0.0 0.0 1.0 \n",
|
||
"46527 0.0 0.0 0.0 \n",
|
||
"\n",
|
||
" work_type_children Residence_type_Urban \\\n",
|
||
"id \n",
|
||
"13276 0.0 1.0 \n",
|
||
"21346 1.0 0.0 \n",
|
||
"59178 1.0 1.0 \n",
|
||
"1679 0.0 0.0 \n",
|
||
"1534 0.0 0.0 \n",
|
||
"... ... ... \n",
|
||
"30463 0.0 1.0 \n",
|
||
"41935 0.0 0.0 \n",
|
||
"68483 0.0 1.0 \n",
|
||
"38617 0.0 1.0 \n",
|
||
"46527 0.0 0.0 \n",
|
||
"\n",
|
||
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
|
||
"id \n",
|
||
"13276 0.0 0.0 \n",
|
||
"21346 0.0 0.0 \n",
|
||
"59178 0.0 0.0 \n",
|
||
"1679 1.0 0.0 \n",
|
||
"1534 0.0 0.0 \n",
|
||
"... ... ... \n",
|
||
"30463 1.0 0.0 \n",
|
||
"41935 0.0 1.0 \n",
|
||
"68483 1.0 0.0 \n",
|
||
"38617 0.0 1.0 \n",
|
||
"46527 0.0 1.0 \n",
|
||
"\n",
|
||
" smoking_status_smokes hypertension heart_disease stroke \n",
|
||
"id \n",
|
||
"13276 0.0 0 0 0 \n",
|
||
"21346 0.0 0 0 0 \n",
|
||
"59178 0.0 0 0 0 \n",
|
||
"1679 0.0 0 0 0 \n",
|
||
"1534 1.0 0 0 0 \n",
|
||
"... ... ... ... ... \n",
|
||
"30463 0.0 0 0 0 \n",
|
||
"41935 0.0 0 0 0 \n",
|
||
"68483 0.0 0 0 0 \n",
|
||
"38617 0.0 0 0 0 \n",
|
||
"46527 0.0 1 1 0 \n",
|
||
"\n",
|
||
"[4088 rows x 16 columns]"
|
||
]
|
||
},
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"preprocessing_result = pipeline_end_reg.fit_transform(X_train)\n",
|
||
"preprocessed_df = pd.DataFrame(\n",
|
||
" preprocessing_result,\n",
|
||
" columns=pipeline_end_reg.get_feature_names_out(),\n",
|
||
")\n",
|
||
"\n",
|
||
"preprocessed_df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Подберем оптимальные гиперпараметры для каждой из выбранных моделей методом поиска по сетке и сформируем их набор.\n",
|
||
"\n",
|
||
"knn -- k-ближайших соседей\n",
|
||
"\n",
|
||
"random_forest -- метод случайного леса (набор деревьев решений)\n",
|
||
"\n",
|
||
"mlp -- многослойный персептрон (нейронная сеть)"
|
||
]
|
||
},
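To make the idea of grid search concrete, here is a minimal sketch in plain Python. It is not the notebook's code: the one-dimensional k-NN regressor, the three interleaved folds, and the names `knn_predict`, `cv_mse` and `grid_search` are all assumptions made for illustration. The notebook itself relies on scikit-learn's `GridSearchCV` with `scoring='neg_mean_squared_error'`, which maximizes the negated MSE; minimizing the MSE directly, as below, selects the same candidate.

```python
import itertools


def knn_predict(train_x, train_y, x, k, weights):
    """Predict y for one x with a 1-D k-nearest-neighbours regressor."""
    nearest = sorted(zip(train_x, train_y), key=lambda p: abs(p[0] - x))[:k]
    if weights == "uniform":
        return sum(y for _, y in nearest) / len(nearest)
    # "distance": closer neighbours get larger (inverse-distance) weights
    w = [1.0 / (abs(nx - x) + 1e-9) for nx, _ in nearest]
    return sum(wi * y for wi, (_, y) in zip(w, nearest)) / sum(w)


def cv_mse(xs, ys, k, weights, n_folds=3):
    """Mean squared error averaged over n_folds interleaved folds."""
    fold_scores = []
    for fold in range(n_folds):
        val_idx = [i for i in range(len(xs)) if i % n_folds == fold]
        trn_idx = [i for i in range(len(xs)) if i % n_folds != fold]
        trn_x = [xs[i] for i in trn_idx]
        trn_y = [ys[i] for i in trn_idx]
        sq_err = [
            (ys[i] - knn_predict(trn_x, trn_y, xs[i], k, weights)) ** 2
            for i in val_idx
        ]
        fold_scores.append(sum(sq_err) / len(sq_err))
    return sum(fold_scores) / n_folds


def grid_search(xs, ys, grid):
    """Try every combination in the grid and keep the lowest CV error."""
    names = sorted(grid)
    best_params, best_score = None, float("inf")
    for values in itertools.product(*(grid[name] for name in names)):
        params = dict(zip(names, values))
        score = cv_mse(xs, ys, params["n_neighbors"], params["weights"])
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Toy data on a straight line, invented purely for illustration
toy_x = [float(i) for i in range(12)]
toy_y = [2.0 * x + 1.0 for x in toy_x]
toy_grid = {"n_neighbors": [1, 2, 3], "weights": ["uniform", "distance"]}

best_params, best_score = grid_search(toy_x, toy_y, toy_grid)
print(best_params, best_score)
```

The cost grows with the product of the grid sizes times the number of folds, which is why the larger grids in the next cell can take a long time to evaluate.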
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"d:\\code\\mai\\labs\\AIM-PIbd-31-Bakalskaya-E-D\\lab_4\\venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
||
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Лучшие параметры для knn: {'n_jobs': -1, 'n_neighbors': 30, 'weights': 'uniform'}\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"d:\\code\\mai\\labs\\AIM-PIbd-31-Bakalskaya-E-D\\lab_4\\venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
||
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Лучшие параметры для random_forest: {'criterion': 'squared_error', 'max_depth': 7, 'max_features': 'sqrt', 'n_estimators': 250, 'n_jobs': -1, 'random_state': 9}\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Словарь с вариантами гиперпараметров для каждой модели\n",
|
||
"param_grids = {\n",
|
||
" \"knn\": {\n",
|
||
" \"n_neighbors\": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], \n",
|
||
" \"weights\": ['uniform', 'distance'],\n",
|
||
" \"n_jobs\": [-1]\n",
|
||
" },\n",
|
||
" \"random_forest\": {\n",
|
||
" \"n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
|
||
" \"max_features\": [\"sqrt\", \"log2\", 2],\n",
|
||
" \"max_depth\": [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
|
||
" \"criterion\": [\"squared_error\", \"absolute_error\", \"poisson\"],\n",
|
||
" \"random_state\": [random_state],\n",
|
||
" \"n_jobs\": [-1]\n",
|
||
" },\n",
|
||
" \"mlp\": {\n",
|
||
" \"solver\": ['adam'], \n",
|
||
" \"max_iter\": [1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000], \n",
|
||
" \"alpha\": 10.0 ** -np.arange(1, 10), \n",
|
||
" \"hidden_layer_sizes\":np.arange(10, 15), \n",
|
||
" \"early_stopping\": [True, False],\n",
|
||
" \"random_state\": [random_state]\n",
|
||
" }\n",
|
||
"}\n",
|
||
"\n",
|
||
"# Создаем экземпляры моделей\n",
|
||
"models = {\n",
|
||
" \"knn\": neighbors.KNeighborsRegressor(),\n",
|
||
" \"random_forest\": ensemble.RandomForestRegressor(),\n",
|
||
" \"mlp\": neural_network.MLPRegressor()\n",
|
||
"}\n",
|
||
"\n",
|
||
"# Словарь для хранения моделей с их лучшими параметрами\n",
|
||
"class_models = {}\n",
|
||
"\n",
|
||
"# Выполнение поиска по сетке для каждой модели\n",
|
||
"for model_name, model in models.items():\n",
|
||
" # Создаем GridSearchCV для текущей модели\n",
|
||
" gs_optimizer = GridSearchCV(estimator=model, param_grid=param_grids[model_name], scoring='neg_mean_squared_error', n_jobs=-1)\n",
|
||
" \n",
|
||
" # Обучаем GridSearchCV\n",
|
||
" gs_optimizer.fit(preprocessed_df, y_train.values.ravel())\n",
|
||
" \n",
|
||
" # Получаем лучшие параметры\n",
|
||
" best_params = gs_optimizer.best_params_\n",
|
||
" print(f\"Лучшие параметры для {model_name}: {best_params}\")\n",
|
||
" \n",
|
||
" class_models[model_name] = {\n",
|
||
" \"model\": model.set_params(**best_params) # Настраиваем модель с лучшими параметрами\n",
|
||
" }"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Далее обучим модели и оценим их качество."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"for model_name in class_models.keys():\n",
|
||
" print(f\"Model: {model_name}\")\n",
|
||
" \n",
|
||
" model = class_models[model_name][\"model\"]\n",
|
||
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end_reg), (\"model\", model)])\n",
|
||
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
||
"\n",
|
||
" y_train_pred = model_pipeline.predict(X_train)\n",
|
||
" y_test_pred = model_pipeline.predict(X_test)\n",
|
||
"\n",
|
||
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
||
" class_models[model_name][\"train_preds\"] = y_train_pred\n",
|
||
" class_models[model_name][\"preds\"] = y_test_pred\n",
|
||
" \n",
|
||
" class_models[model_name][\"RMSE_train\"] = math.sqrt(\n",
|
||
" mean_squared_error(y_train, y_train_pred)\n",
|
||
" )\n",
|
||
" class_models[model_name][\"RMSE_test\"] = math.sqrt(\n",
|
||
" mean_squared_error(y_test, y_test_pred)\n",
|
||
" )\n",
|
||
" class_models[model_name][\"RMAE_test\"] = math.sqrt(\n",
|
||
" mean_absolute_error(y_test, y_test_pred)\n",
|
||
" )\n",
|
||
" class_models[model_name][\"R2_test\"] = r2_score(y_test, y_test_pred)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"RMSE, RMAE, R2:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"reg_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
||
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
|
||
"]\n",
|
||
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
|
||
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
|
||
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Результаты графиками:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Создаем графики для всех моделей\n",
|
||
"for model_name, model_data in class_models.items():\n",
|
||
" print(f\"Model: {model_name}\")\n",
|
||
" y_pred = model_data[\"preds\"]\n",
|
||
" plt.figure(figsize=(10, 6))\n",
|
||
" plt.scatter(y_test, y_pred, alpha=0.5)\n",
|
||
" plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)\n",
|
||
" plt.xlabel('Фактический уровень глюкозы')\n",
|
||
" plt.ylabel('Прогнозируемый уровень глюкозы')\n",
|
||
" plt.title(f\"Model: {model_name}\")\n",
|
||
" plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"На представленных графиках можно заметить, что модели в целом не демонстрируют высокого качества. Визуализация их предсказаний показывает сильное рассеивание вокруг идеальной линии y = x, что указывает на значительные отклонения предсказаний от фактических значений.\n",
|
||
"\n",
|
||
"Тем не менее ориентир, хоть возможно и не столь значительно, каждая из моделей превосходит по всем показателям. Особенно заметное улучшение в \n",
|
||
"R2, которая переходит из отрицательного значения в положительное, что говорит о том, что модели хотя бы частично объясняют дисперсию данных. \n",
|
||
"\n",
|
||
"Кроме того, можно сказать, что все модели имеет умеренную дисперсию и не сильно подвержены переобучению, потому что разница между RMSE на обучении и тесте незначительна.\n",
|
||
"\n",
|
||
"Итоговые выводы:\n",
|
||
"- Наиболее качественная модель: MLP, так как она показывает наименьшее значение RMSE и наибольшее значение R2, что указывает на лучшую точность и объяснение дисперсии целевой переменной.\n",
|
||
"\n",
|
||
"- Random Forest: Близок по производительности к MLP, с чуть большим RMSE, но является более устойчивой моделью с небольшими отклонениями между обучением и тестом.\n",
|
||
"\n",
|
||
"- KNN: Худшая модель, демонстрирующая наибольшие ошибки и низкое R2, что указывает на необходимость улучшения или использования другой модели для данной задачи."
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|