MII/lec4_1.ipynb
2024-11-15 23:06:57 +04:00

2491 lines
220 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Загрузка набора данных"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>9046</th>\n",
" <td>Male</td>\n",
" <td>67.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>228.69</td>\n",
" <td>36.6</td>\n",
" <td>formerly smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>51676</th>\n",
" <td>Female</td>\n",
" <td>61.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>202.21</td>\n",
" <td>NaN</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31112</th>\n",
" <td>Male</td>\n",
" <td>80.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>105.92</td>\n",
" <td>32.5</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>60182</th>\n",
" <td>Female</td>\n",
" <td>49.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>171.23</td>\n",
" <td>34.4</td>\n",
" <td>smokes</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1665</th>\n",
" <td>Female</td>\n",
" <td>79.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>174.12</td>\n",
" <td>24.0</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18234</th>\n",
" <td>Female</td>\n",
" <td>80.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>83.75</td>\n",
" <td>NaN</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>44873</th>\n",
" <td>Female</td>\n",
" <td>81.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>125.20</td>\n",
" <td>40.0</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19723</th>\n",
" <td>Female</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>82.99</td>\n",
" <td>30.6</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37544</th>\n",
" <td>Male</td>\n",
" <td>51.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>166.29</td>\n",
" <td>25.6</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>44679</th>\n",
" <td>Female</td>\n",
" <td>44.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>85.28</td>\n",
" <td>26.2</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5110 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"9046 Male 67.0 0 1 Yes Private \n",
"51676 Female 61.0 0 0 Yes Self-employed \n",
"31112 Male 80.0 0 1 Yes Private \n",
"60182 Female 49.0 0 0 Yes Private \n",
"1665 Female 79.0 1 0 Yes Self-employed \n",
"... ... ... ... ... ... ... \n",
"18234 Female 80.0 1 0 Yes Private \n",
"44873 Female 81.0 0 0 Yes Self-employed \n",
"19723 Female 35.0 0 0 Yes Self-employed \n",
"37544 Male 51.0 0 0 Yes Private \n",
"44679 Female 44.0 0 0 Yes Govt_job \n",
"\n",
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
"id \n",
"9046 Urban 228.69 36.6 formerly smoked 1 \n",
"51676 Rural 202.21 NaN never smoked 1 \n",
"31112 Rural 105.92 32.5 never smoked 1 \n",
"60182 Urban 171.23 34.4 smokes 1 \n",
"1665 Rural 174.12 24.0 never smoked 1 \n",
"... ... ... ... ... ... \n",
"18234 Urban 83.75 NaN never smoked 0 \n",
"44873 Urban 125.20 40.0 never smoked 0 \n",
"19723 Rural 82.99 30.6 never smoked 0 \n",
"37544 Rural 166.29 25.6 formerly smoked 0 \n",
"44679 Urban 85.28 26.2 Unknown 0 \n",
"\n",
"[5110 rows x 11 columns]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"random_state=9\n",
"\n",
"df = pd.read_csv(\"data/healthcare.csv\", index_col=\"id\")\n",
"\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n",
"\n",
"Целевой признак -- heart_disease - есть ли заболевания сердца\n",
". x - полная выборка, y - gear box столбец\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>17762</th>\n",
" <td>Female</td>\n",
" <td>3.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Rural</td>\n",
" <td>114.88</td>\n",
" <td>19.1</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48652</th>\n",
" <td>Female</td>\n",
" <td>8.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Urban</td>\n",
" <td>83.55</td>\n",
" <td>22.4</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6903</th>\n",
" <td>Female</td>\n",
" <td>15.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Rural</td>\n",
" <td>77.57</td>\n",
" <td>18.3</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2903</th>\n",
" <td>Female</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>123.83</td>\n",
" <td>23.8</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>58153</th>\n",
" <td>Female</td>\n",
" <td>18.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>123.66</td>\n",
" <td>22.2</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>34084</th>\n",
" <td>Male</td>\n",
" <td>7.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Urban</td>\n",
" <td>77.12</td>\n",
" <td>18.6</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11176</th>\n",
" <td>Male</td>\n",
" <td>9.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>children</td>\n",
" <td>Rural</td>\n",
" <td>85.02</td>\n",
" <td>16.3</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>52554</th>\n",
" <td>Male</td>\n",
" <td>19.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>64.92</td>\n",
" <td>22.5</td>\n",
" <td>Unknown</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10381</th>\n",
" <td>Female</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>91.00</td>\n",
" <td>33.3</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>70884</th>\n",
" <td>Female</td>\n",
" <td>34.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>79.80</td>\n",
" <td>37.4</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"17762 Female 3.0 0 0 No children \n",
"48652 Female 8.0 0 0 No children \n",
"6903 Female 15.0 0 0 No children \n",
"2903 Female 35.0 0 0 No Private \n",
"58153 Female 18.0 0 0 No Private \n",
"... ... ... ... ... ... ... \n",
"34084 Male 7.0 0 0 No children \n",
"11176 Male 9.0 0 0 No children \n",
"52554 Male 19.0 0 0 No Private \n",
"10381 Female 38.0 1 0 Yes Self-employed \n",
"70884 Female 34.0 0 0 Yes Private \n",
"\n",
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
"id \n",
"17762 Rural 114.88 19.1 Unknown 0 \n",
"48652 Urban 83.55 22.4 Unknown 0 \n",
"6903 Rural 77.57 18.3 Unknown 0 \n",
"2903 Rural 123.83 23.8 never smoked 0 \n",
"58153 Urban 123.66 22.2 never smoked 0 \n",
"... ... ... ... ... ... \n",
"34084 Urban 77.12 18.6 Unknown 0 \n",
"11176 Rural 85.02 16.3 Unknown 0 \n",
"52554 Rural 64.92 22.5 Unknown 0 \n",
"10381 Urban 91.00 33.3 never smoked 0 \n",
"70884 Urban 79.80 37.4 smokes 0 \n",
"\n",
"[4088 rows x 11 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>heart_disease</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>17762</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48652</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6903</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2903</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>58153</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>34084</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11176</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>52554</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10381</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>70884</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" heart_disease\n",
"id \n",
"17762 0\n",
"48652 0\n",
"6903 0\n",
"2903 0\n",
"58153 0\n",
"... ...\n",
"34084 0\n",
"11176 0\n",
"52554 0\n",
"10381 0\n",
"70884 0\n",
"\n",
"[4088 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>gender</th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>ever_married</th>\n",
" <th>work_type</th>\n",
" <th>Residence_type</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>smoking_status</th>\n",
" <th>stroke</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2520</th>\n",
" <td>Female</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>84.90</td>\n",
" <td>26.2</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>56855</th>\n",
" <td>Male</td>\n",
" <td>46.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>137.77</td>\n",
" <td>29.3</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27034</th>\n",
" <td>Female</td>\n",
" <td>65.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>82.72</td>\n",
" <td>29.8</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>641</th>\n",
" <td>Male</td>\n",
" <td>52.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Rural</td>\n",
" <td>87.26</td>\n",
" <td>40.1</td>\n",
" <td>smokes</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65407</th>\n",
" <td>Female</td>\n",
" <td>64.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>65.46</td>\n",
" <td>32.5</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40447</th>\n",
" <td>Female</td>\n",
" <td>59.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Private</td>\n",
" <td>Rural</td>\n",
" <td>82.42</td>\n",
" <td>28.8</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>56324</th>\n",
" <td>Female</td>\n",
" <td>53.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Rural</td>\n",
" <td>81.76</td>\n",
" <td>34.3</td>\n",
" <td>formerly smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4813</th>\n",
" <td>Male</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>No</td>\n",
" <td>Private</td>\n",
" <td>Urban</td>\n",
" <td>112.98</td>\n",
" <td>44.7</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14372</th>\n",
" <td>Male</td>\n",
" <td>50.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Self-employed</td>\n",
" <td>Urban</td>\n",
" <td>192.16</td>\n",
" <td>43.6</td>\n",
" <td>never smoked</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50522</th>\n",
" <td>Female</td>\n",
" <td>72.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Yes</td>\n",
" <td>Govt_job</td>\n",
" <td>Urban</td>\n",
" <td>131.41</td>\n",
" <td>28.4</td>\n",
" <td>never smoked</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1022 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" gender age hypertension heart_disease ever_married work_type \\\n",
"id \n",
"2520 Female 26.0 0 0 Yes Private \n",
"56855 Male 46.0 0 0 Yes Private \n",
"27034 Female 65.0 0 0 Yes Govt_job \n",
"641 Male 52.0 0 0 Yes Govt_job \n",
"65407 Female 64.0 0 0 Yes Self-employed \n",
"... ... ... ... ... ... ... \n",
"40447 Female 59.0 0 0 Yes Private \n",
"56324 Female 53.0 0 0 Yes Self-employed \n",
"4813 Male 27.0 0 0 No Private \n",
"14372 Male 50.0 0 0 Yes Self-employed \n",
"50522 Female 72.0 0 0 Yes Govt_job \n",
"\n",
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
"id \n",
"2520 Rural 84.90 26.2 never smoked 0 \n",
"56855 Urban 137.77 29.3 never smoked 0 \n",
"27034 Urban 82.72 29.8 smokes 0 \n",
"641 Rural 87.26 40.1 smokes 0 \n",
"65407 Rural 65.46 32.5 formerly smoked 0 \n",
"... ... ... ... ... ... \n",
"40447 Rural 82.42 28.8 never smoked 0 \n",
"56324 Rural 81.76 34.3 formerly smoked 0 \n",
"4813 Urban 112.98 44.7 never smoked 0 \n",
"14372 Urban 192.16 43.6 never smoked 0 \n",
"50522 Urban 131.41 28.4 never smoked 1 \n",
"\n",
"[1022 rows x 11 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>heart_disease</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2520</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>56855</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27034</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>641</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65407</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40447</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>56324</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4813</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14372</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50522</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1022 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" heart_disease\n",
"id \n",
"2520 0\n",
"56855 0\n",
"27034 0\n",
"641 0\n",
"65407 0\n",
"... ...\n",
"40447 0\n",
"56324 0\n",
"4813 0\n",
"14372 0\n",
"50522 0\n",
"\n",
"[1022 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from utils import split_stratified_into_train_val_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"heart_disease\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В итоге, этот код выполняет следующие действия:\n",
"\n",
"* Заполняет пропущенные значения: В числовых столбцах медианой, в категориальных - значением \"unknown\".\n",
"* Стандартизирует числовые данные: приводит их к нулевому среднему и единичному стандартному отклонению.\n",
"* Преобразует категориальные данные: использует one-hot-кодирование.\n",
"* Удаляет ненужные столбцы: из списка `columns_to_drop`.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование конвейера для классификации данных\n",
"\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"features_engineering -- трансформер для конструирования признаков\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"features_postprocessing -- трансформер для унитарного кодирования новых признаков\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков\n",
"\n",
"Конвейер выполняется последовательно.\n",
"\n",
"Трансформер выполняет параллельно для указанного набора колонок.\n",
"\n",
"Документация: \n",
"\n",
"https://scikit-learn.org/1.5/api/sklearn.pipeline.html\n",
"\n",
"https://scikit-learn.org/1.5/modules/generated/sklearn.compose.ColumnTransformer.html#sklearn.compose.ColumnTransformer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"from transformers import TitanicFeatures\n",
"\n",
"\n",
"columns_to_drop = []\n",
"#columns_to_drop = [\"Doors\", \"Color\", \"Gear box type\", \"Prod_year\", \"Mileage\", \"Airbags\", \"Levy\", \"Leather_interior\", \"Fuel type\", \"Drive wheels\"]\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" #(\"prepocessing_features\", cat_imputer, [\"Name\", \"Cabin\"]),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# features_engineering = ColumnTransformer(\n",
"# verbose_feature_names_out=False,\n",
"# transformers=[\n",
"# (\"add_features\", TitanicFeatures(), [\"Name\", \"Cabin\"]),\n",
"# ],\n",
"# remainder=\"passthrough\",\n",
"# )\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# features_postprocessing = ColumnTransformer(\n",
"# verbose_feature_names_out=False,\n",
"# transformers=[\n",
"# (\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n",
"# ],\n",
"# remainder=\"passthrough\",\n",
"# )\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" # (\"features_engineering\", features_engineering),\n",
" (\"drop_columns\", drop_columns),\n",
" # (\"features_postprocessing\", features_postprocessing),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Демонстрация работы конвейера для предобработки данных при классификации"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>hypertension</th>\n",
" <th>heart_disease</th>\n",
" <th>avg_glucose_level</th>\n",
" <th>bmi</th>\n",
" <th>stroke</th>\n",
" <th>gender_Male</th>\n",
" <th>ever_married_Yes</th>\n",
" <th>work_type_Never_worked</th>\n",
" <th>work_type_Private</th>\n",
" <th>work_type_Self-employed</th>\n",
" <th>work_type_children</th>\n",
" <th>Residence_type_Urban</th>\n",
" <th>smoking_status_formerly smoked</th>\n",
" <th>smoking_status_never smoked</th>\n",
" <th>smoking_status_smokes</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>17762</th>\n",
" <td>-1.767059</td>\n",
" <td>-0.331155</td>\n",
" <td>-0.239061</td>\n",
" <td>0.202705</td>\n",
" <td>-1.260486</td>\n",
" <td>-0.221387</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48652</th>\n",
" <td>-1.546212</td>\n",
" <td>-0.331155</td>\n",
" <td>-0.239061</td>\n",
" <td>-0.493692</td>\n",
" <td>-0.834786</td>\n",
" <td>-0.221387</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6903</th>\n",
" <td>-1.237027</td>\n",
" <td>-0.331155</td>\n",
" <td>-0.239061</td>\n",
" <td>-0.626614</td>\n",
" <td>-1.363686</td>\n",
" <td>-0.221387</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2903</th>\n",
" <td>-0.353640</td>\n",
" <td>-0.331155</td>\n",
" <td>-0.239061</td>\n",
" <td>0.401643</td>\n",
" <td>-0.654186</td>\n",
" <td>-0.221387</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>58153</th>\n",
" <td>-1.104519</td>\n",
" <td>-0.331155</td>\n",
" <td>-0.239061</td>\n",
" <td>0.397865</td>\n",
" <td>-0.860586</td>\n",
" <td>-0.221387</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>34084</th>\n",
" <td>-1.590381</td>\n",
" <td>-0.331155</td>\n",
" <td>-0.239061</td>\n",
" <td>-0.636617</td>\n",
" <td>-1.324986</td>\n",
" <td>-0.221387</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11176</th>\n",
" <td>-1.502043</td>\n",
" <td>-0.331155</td>\n",
" <td>-0.239061</td>\n",
" <td>-0.461017</td>\n",
" <td>-1.621686</td>\n",
" <td>-0.221387</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>52554</th>\n",
" <td>-1.060349</td>\n",
" <td>-0.331155</td>\n",
" <td>-0.239061</td>\n",
" <td>-0.907796</td>\n",
" <td>-0.821886</td>\n",
" <td>-0.221387</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10381</th>\n",
" <td>-0.221132</td>\n",
" <td>3.019737</td>\n",
" <td>-0.239061</td>\n",
" <td>-0.328095</td>\n",
" <td>0.571314</td>\n",
" <td>-0.221387</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>70884</th>\n",
" <td>-0.397810</td>\n",
" <td>-0.331155</td>\n",
" <td>-0.239061</td>\n",
" <td>-0.577046</td>\n",
" <td>1.100214</td>\n",
" <td>-0.221387</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4088 rows × 16 columns</p>\n",
"</div>"
],
"text/plain": [
" age hypertension heart_disease avg_glucose_level bmi \\\n",
"id \n",
"17762 -1.767059 -0.331155 -0.239061 0.202705 -1.260486 \n",
"48652 -1.546212 -0.331155 -0.239061 -0.493692 -0.834786 \n",
"6903 -1.237027 -0.331155 -0.239061 -0.626614 -1.363686 \n",
"2903 -0.353640 -0.331155 -0.239061 0.401643 -0.654186 \n",
"58153 -1.104519 -0.331155 -0.239061 0.397865 -0.860586 \n",
"... ... ... ... ... ... \n",
"34084 -1.590381 -0.331155 -0.239061 -0.636617 -1.324986 \n",
"11176 -1.502043 -0.331155 -0.239061 -0.461017 -1.621686 \n",
"52554 -1.060349 -0.331155 -0.239061 -0.907796 -0.821886 \n",
"10381 -0.221132 3.019737 -0.239061 -0.328095 0.571314 \n",
"70884 -0.397810 -0.331155 -0.239061 -0.577046 1.100214 \n",
"\n",
" stroke gender_Male ever_married_Yes work_type_Never_worked \\\n",
"id \n",
"17762 -0.221387 0.0 0.0 0.0 \n",
"48652 -0.221387 0.0 0.0 0.0 \n",
"6903 -0.221387 0.0 0.0 0.0 \n",
"2903 -0.221387 0.0 0.0 0.0 \n",
"58153 -0.221387 0.0 0.0 0.0 \n",
"... ... ... ... ... \n",
"34084 -0.221387 1.0 0.0 0.0 \n",
"11176 -0.221387 1.0 0.0 0.0 \n",
"52554 -0.221387 1.0 0.0 0.0 \n",
"10381 -0.221387 0.0 1.0 0.0 \n",
"70884 -0.221387 0.0 1.0 0.0 \n",
"\n",
" work_type_Private work_type_Self-employed work_type_children \\\n",
"id \n",
"17762 0.0 0.0 1.0 \n",
"48652 0.0 0.0 1.0 \n",
"6903 0.0 0.0 1.0 \n",
"2903 1.0 0.0 0.0 \n",
"58153 1.0 0.0 0.0 \n",
"... ... ... ... \n",
"34084 0.0 0.0 1.0 \n",
"11176 0.0 0.0 1.0 \n",
"52554 1.0 0.0 0.0 \n",
"10381 0.0 1.0 0.0 \n",
"70884 1.0 0.0 0.0 \n",
"\n",
" Residence_type_Urban smoking_status_formerly smoked \\\n",
"id \n",
"17762 0.0 0.0 \n",
"48652 1.0 0.0 \n",
"6903 0.0 0.0 \n",
"2903 0.0 0.0 \n",
"58153 1.0 0.0 \n",
"... ... ... \n",
"34084 1.0 0.0 \n",
"11176 0.0 0.0 \n",
"52554 0.0 0.0 \n",
"10381 1.0 0.0 \n",
"70884 1.0 0.0 \n",
"\n",
" smoking_status_never smoked smoking_status_smokes \n",
"id \n",
"17762 0.0 0.0 \n",
"48652 0.0 0.0 \n",
"6903 0.0 0.0 \n",
"2903 1.0 0.0 \n",
"58153 1.0 0.0 \n",
"... ... ... \n",
"34084 0.0 0.0 \n",
"11176 0.0 0.0 \n",
"52554 0.0 0.0 \n",
"10381 1.0 0.0 \n",
"70884 0.0 1.0 \n",
"\n",
"[4088 rows x 16 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для классификации\n",
"\n",
"logistic -- логистическая регрессия\n",
"\n",
"ridge -- гребневая регрессия\n",
"\n",
"decision_tree -- дерево решений\n",
"\n",
"knn -- k-ближайших соседей\n",
"\n",
"naive_bayes -- наивный Байесовский классификатор\n",
"\n",
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"\n",
"random_forest -- метод случайного леса (набор деревьев решений)\n",
"\n",
"mlp -- многослойный персептрон (нейронная сеть)\n",
"\n",
"Документация: https://scikit-learn.org/1.5/supervised_learning.html"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=100000,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" heart_disease\n",
"id \n",
"17762 0\n",
"48652 0\n",
"6903 0\n",
"2903 0\n",
"58153 0\n",
" heart_disease\n",
"id \n",
"2520 0\n",
"56855 0\n",
"27034 0\n",
"641 0\n",
"65407 0\n"
]
}
],
"source": [
"# print(y_train.dtypes)\n",
"# print(y_test.dtypes)\n",
"# df.info()\n",
"print(y_train.head())\n",
"print(y_test.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"[0 0 0 ... 0 0 0]\n",
"[0 0 0 ... 0 0 0]\n",
"Model: ridge\n",
"[0 0 0 ... 0 0 0]\n",
"[0 0 0 ... 0 0 0]\n",
"Model: decision_tree\n",
"[0 0 0 ... 0 0 0]\n",
"[0 0 0 ... 0 0 0]\n",
"Model: knn\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 ... 0 0 0]\n",
"[0 0 0 ... 0 0 0]\n",
"Model: naive_bayes\n",
"[0 0 0 ... 0 0 0]\n",
"[0 0 0 ... 0 0 0]\n",
"Model: gradient_boosting\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 ... 0 0 0]\n",
"[0 0 0 ... 0 0 0]\n",
"Model: random_forest\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 ... 0 0 0]\n",
"[0 0 0 ... 0 0 0]\n",
"Model: mlp\n",
"[0 0 0 ... 0 0 0]\n",
"[0 0 0 ... 0 0 0]\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" print(y_train_predict)\n",
" print(y_test_predict)\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict, \n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict, \n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict, \n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Сводная таблица оценок качества для использованных моделей классификации\n",
"\n",
"Документация: https://scikit-learn.org/1.5/modules/model_evaluation.html"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Матрица неточностей"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"bad heart\", \"nice heart\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_4edf2_row0_col0, #T_4edf2_row0_col1, #T_4edf2_row0_col2, #T_4edf2_row0_col3, #T_4edf2_row1_col0, #T_4edf2_row1_col1, #T_4edf2_row1_col2, #T_4edf2_row1_col3, #T_4edf2_row2_col0, #T_4edf2_row2_col1, #T_4edf2_row2_col2, #T_4edf2_row2_col3, #T_4edf2_row3_col0, #T_4edf2_row3_col1, #T_4edf2_row3_col2, #T_4edf2_row3_col3, #T_4edf2_row4_col0, #T_4edf2_row4_col1, #T_4edf2_row4_col2, #T_4edf2_row4_col3, #T_4edf2_row5_col0, #T_4edf2_row5_col1, #T_4edf2_row5_col2, #T_4edf2_row5_col3, #T_4edf2_row6_col0, #T_4edf2_row6_col1, #T_4edf2_row6_col2, #T_4edf2_row6_col3, #T_4edf2_row7_col0, #T_4edf2_row7_col1, #T_4edf2_row7_col2, #T_4edf2_row7_col3 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4edf2_row0_col4, #T_4edf2_row0_col5, #T_4edf2_row0_col6, #T_4edf2_row0_col7, #T_4edf2_row1_col4, #T_4edf2_row1_col5, #T_4edf2_row1_col6, #T_4edf2_row1_col7, #T_4edf2_row2_col4, #T_4edf2_row2_col5, #T_4edf2_row2_col6, #T_4edf2_row2_col7, #T_4edf2_row3_col4, #T_4edf2_row3_col5, #T_4edf2_row3_col6, #T_4edf2_row3_col7, #T_4edf2_row4_col4, #T_4edf2_row4_col5, #T_4edf2_row4_col6, #T_4edf2_row4_col7, #T_4edf2_row5_col4, #T_4edf2_row5_col5, #T_4edf2_row5_col6, #T_4edf2_row5_col7, #T_4edf2_row6_col4, #T_4edf2_row6_col5, #T_4edf2_row6_col6, #T_4edf2_row6_col7, #T_4edf2_row7_col4, #T_4edf2_row7_col5, #T_4edf2_row7_col6, #T_4edf2_row7_col7 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_4edf2\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_4edf2_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_4edf2_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_4edf2_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_4edf2_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_4edf2_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_4edf2_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_4edf2_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_4edf2_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_4edf2_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_4edf2_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_4edf2_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_4edf2_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_4edf2_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_4edf2_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_4edf2_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_4edf2_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_4edf2_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4edf2_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_4edf2_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_4edf2_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_4edf2_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_4edf2_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_4edf2_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_4edf2_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_4edf2_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_4edf2_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4edf2_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_4edf2_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_4edf2_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_4edf2_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_4edf2_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_4edf2_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_4edf2_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_4edf2_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_4edf2_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4edf2_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_4edf2_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_4edf2_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_4edf2_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_4edf2_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_4edf2_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" <td id=\"T_4edf2_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
" <td id=\"T_4edf2_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
" <td id=\"T_4edf2_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4edf2_level0_row4\" class=\"row_heading level0 row4\" >naive_bayes</th>\n",
" <td id=\"T_4edf2_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_4edf2_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_4edf2_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_4edf2_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_4edf2_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" <td id=\"T_4edf2_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
" <td id=\"T_4edf2_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
" <td id=\"T_4edf2_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4edf2_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_4edf2_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_4edf2_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_4edf2_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_4edf2_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_4edf2_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" <td id=\"T_4edf2_row5_col5\" class=\"data row5 col5\" >1.000000</td>\n",
" <td id=\"T_4edf2_row5_col6\" class=\"data row5 col6\" >1.000000</td>\n",
" <td id=\"T_4edf2_row5_col7\" class=\"data row5 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4edf2_level0_row6\" class=\"row_heading level0 row6\" >random_forest</th>\n",
" <td id=\"T_4edf2_row6_col0\" class=\"data row6 col0\" >1.000000</td>\n",
" <td id=\"T_4edf2_row6_col1\" class=\"data row6 col1\" >1.000000</td>\n",
" <td id=\"T_4edf2_row6_col2\" class=\"data row6 col2\" >1.000000</td>\n",
" <td id=\"T_4edf2_row6_col3\" class=\"data row6 col3\" >1.000000</td>\n",
" <td id=\"T_4edf2_row6_col4\" class=\"data row6 col4\" >1.000000</td>\n",
" <td id=\"T_4edf2_row6_col5\" class=\"data row6 col5\" >1.000000</td>\n",
" <td id=\"T_4edf2_row6_col6\" class=\"data row6 col6\" >1.000000</td>\n",
" <td id=\"T_4edf2_row6_col7\" class=\"data row6 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4edf2_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_4edf2_row7_col0\" class=\"data row7 col0\" >1.000000</td>\n",
" <td id=\"T_4edf2_row7_col1\" class=\"data row7 col1\" >1.000000</td>\n",
" <td id=\"T_4edf2_row7_col2\" class=\"data row7 col2\" >1.000000</td>\n",
" <td id=\"T_4edf2_row7_col3\" class=\"data row7 col3\" >1.000000</td>\n",
" <td id=\"T_4edf2_row7_col4\" class=\"data row7 col4\" >1.000000</td>\n",
" <td id=\"T_4edf2_row7_col5\" class=\"data row7 col5\" >1.000000</td>\n",
" <td id=\"T_4edf2_row7_col6\" class=\"data row7 col6\" >1.000000</td>\n",
" <td id=\"T_4edf2_row7_col7\" class=\"data row7 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x26164c85940>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_02096_row0_col0, #T_02096_row0_col1, #T_02096_row1_col0, #T_02096_row1_col1, #T_02096_row2_col0, #T_02096_row2_col1, #T_02096_row3_col0, #T_02096_row3_col1, #T_02096_row4_col0, #T_02096_row4_col1, #T_02096_row5_col0, #T_02096_row5_col1, #T_02096_row6_col0, #T_02096_row6_col1, #T_02096_row7_col0, #T_02096_row7_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_02096_row0_col2, #T_02096_row0_col3, #T_02096_row0_col4, #T_02096_row1_col2, #T_02096_row1_col3, #T_02096_row1_col4, #T_02096_row2_col2, #T_02096_row2_col3, #T_02096_row2_col4, #T_02096_row3_col2, #T_02096_row3_col3, #T_02096_row3_col4, #T_02096_row4_col2, #T_02096_row4_col3, #T_02096_row4_col4, #T_02096_row5_col2, #T_02096_row5_col3, #T_02096_row5_col4, #T_02096_row6_col2, #T_02096_row6_col3, #T_02096_row6_col4, #T_02096_row7_col2, #T_02096_row7_col3, #T_02096_row7_col4 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_02096\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_02096_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_02096_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_02096_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_02096_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_02096_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_02096_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_02096_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_02096_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_02096_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_02096_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_02096_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_02096_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_02096_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_02096_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_02096_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_02096_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_02096_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_02096_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_02096_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_02096_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_02096_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_02096_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_02096_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_02096_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_02096_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_02096_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_02096_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_02096_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_02096_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_02096_level0_row4\" class=\"row_heading level0 row4\" >naive_bayes</th>\n",
" <td id=\"T_02096_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_02096_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_02096_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_02096_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_02096_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_02096_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_02096_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_02096_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_02096_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_02096_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_02096_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_02096_level0_row6\" class=\"row_heading level0 row6\" >random_forest</th>\n",
" <td id=\"T_02096_row6_col0\" class=\"data row6 col0\" >1.000000</td>\n",
" <td id=\"T_02096_row6_col1\" class=\"data row6 col1\" >1.000000</td>\n",
" <td id=\"T_02096_row6_col2\" class=\"data row6 col2\" >1.000000</td>\n",
" <td id=\"T_02096_row6_col3\" class=\"data row6 col3\" >1.000000</td>\n",
" <td id=\"T_02096_row6_col4\" class=\"data row6 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_02096_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_02096_row7_col0\" class=\"data row7 col0\" >1.000000</td>\n",
" <td id=\"T_02096_row7_col1\" class=\"data row7 col1\" >1.000000</td>\n",
" <td id=\"T_02096_row7_col2\" class=\"data row7 col2\" >1.000000</td>\n",
" <td id=\"T_02096_row7_col3\" class=\"data row7 col3\" >1.000000</td>\n",
" <td id=\"T_02096_row7_col4\" class=\"data row7 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x26164c17020>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"ename": "KeyError",
"evalue": "'Survived'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyError\u001b[0m Traceback (most recent call last)",
"File \u001b[1;32mc:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\pandas\\core\\indexes\\base.py:3805\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[1;34m(self, key)\u001b[0m\n\u001b[0;32m 3804\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m-> 3805\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasted_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 3806\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
"File \u001b[1;32mindex.pyx:167\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[1;34m()\u001b[0m\n",
"File \u001b[1;32mindex.pyx:196\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[1;34m()\u001b[0m\n",
"File \u001b[1;32mpandas\\\\_libs\\\\hashtable_class_helper.pxi:7081\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[1;34m()\u001b[0m\n",
"File \u001b[1;32mpandas\\\\_libs\\\\hashtable_class_helper.pxi:7089\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[1;34m()\u001b[0m\n",
"\u001b[1;31mKeyError\u001b[0m: 'Survived'",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[1;31mKeyError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[12], line 9\u001b[0m\n\u001b[0;32m 2\u001b[0m preprocessed_df \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame(\n\u001b[0;32m 3\u001b[0m preprocessing_result,\n\u001b[0;32m 4\u001b[0m columns\u001b[38;5;241m=\u001b[39mpipeline_end\u001b[38;5;241m.\u001b[39mget_feature_names_out(),\n\u001b[0;32m 5\u001b[0m )\n\u001b[0;32m 7\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m class_models[best_model][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpreds\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m----> 9\u001b[0m error_index \u001b[38;5;241m=\u001b[39m y_test[\u001b[43my_test\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mSurvived\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;241m!=\u001b[39m y_pred]\u001b[38;5;241m.\u001b[39mindex\u001b[38;5;241m.\u001b[39mtolist()\n\u001b[0;32m 10\u001b[0m display(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mError items count: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(error_index)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 12\u001b[0m error_predicted \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mSeries(y_pred, index\u001b[38;5;241m=\u001b[39my_test\u001b[38;5;241m.\u001b[39mindex)\u001b[38;5;241m.\u001b[39mloc[error_index]\n",
"File \u001b[1;32mc:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\pandas\\core\\frame.py:4102\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[1;34m(self, key)\u001b[0m\n\u001b[0;32m 4100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcolumns\u001b[38;5;241m.\u001b[39mnlevels \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m 4101\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_getitem_multilevel(key)\n\u001b[1;32m-> 4102\u001b[0m indexer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumns\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4103\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer(indexer):\n\u001b[0;32m 4104\u001b[0m indexer \u001b[38;5;241m=\u001b[39m [indexer]\n",
"File \u001b[1;32mc:\\Users\\1\\Desktop\\улгту\\3 курс\\МИИ\\mai\\.venv\\Lib\\site-packages\\pandas\\core\\indexes\\base.py:3812\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[1;34m(self, key)\u001b[0m\n\u001b[0;32m 3807\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(casted_key, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[0;32m 3808\u001b[0m \u001b[38;5;28misinstance\u001b[39m(casted_key, abc\u001b[38;5;241m.\u001b[39mIterable)\n\u001b[0;32m 3809\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28many\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(x, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m casted_key)\n\u001b[0;32m 3810\u001b[0m ):\n\u001b[0;32m 3811\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m InvalidIndexError(key)\n\u001b[1;32m-> 3812\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n\u001b[0;32m 3813\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[0;32m 3814\u001b[0m \u001b[38;5;66;03m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[0;32m 3815\u001b[0m \u001b[38;5;66;03m# InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[0;32m 3816\u001b[0m \u001b[38;5;66;03m# the TypeError.\u001b[39;00m\n\u001b[0;32m 3817\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_indexing_error(key)\n",
"\u001b[1;31mKeyError\u001b[0m: 'Survived'"
]
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"Survived\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 450\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Подбор гиперпараметров методом поиска по сетке\n",
"\n",
"https://www.kaggle.com/code/sociopath00/random-forest-using-gridsearchcv\n",
"\n",
"https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10],\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=7,\n",
" max_features=\"sqrt\",\n",
" n_estimators=30,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'optimized_metrics' is not defined",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[14], line 4\u001b[0m\n\u001b[0;32m 1\u001b[0m _, ax \u001b[38;5;241m=\u001b[39m plt\u001b[38;5;241m.\u001b[39msubplots(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m, figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m10\u001b[39m, \u001b[38;5;241m4\u001b[39m), sharex\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, sharey\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m 2\u001b[0m )\n\u001b[1;32m----> 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m index \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mlen\u001b[39m(\u001b[43moptimized_metrics\u001b[49m)):\n\u001b[0;32m 5\u001b[0m c_matrix \u001b[38;5;241m=\u001b[39m optimized_metrics\u001b[38;5;241m.\u001b[39miloc[index][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mConfusion_matrix\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m 6\u001b[0m disp \u001b[38;5;241m=\u001b[39m ConfusionMatrixDisplay(\n\u001b[0;32m 7\u001b[0m confusion_matrix\u001b[38;5;241m=\u001b[39mc_matrix, display_labels\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDied\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSirvived\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m 8\u001b[0m )\u001b[38;5;241m.\u001b[39mplot(ax\u001b[38;5;241m=\u001b[39max\u001b[38;5;241m.\u001b[39mflat[index])\n",
"\u001b[1;31mNameError\u001b[0m: name 'optimized_metrics' is not defined"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x400 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Died\", \"Sirvived\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}