AIM-PIbd-31-Razubaev-S-M/Lab4/lab4.ipynb

7035 lines
509 KiB
Plaintext
Raw Normal View History

2024-11-08 22:14:23 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Лабораторная 4"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Информация о диабете индейцев Пима"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Pregnancies', 'Glucose', 'BloodPressure', 'SkinThickness', 'Insulin',\n",
" 'BMI', 'DiabetesPedigreeFunction', 'Age', 'Outcome'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6</td>\n",
" <td>148</td>\n",
" <td>72</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" <td>33.6</td>\n",
" <td>0.627</td>\n",
" <td>50</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>85</td>\n",
" <td>66</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>26.6</td>\n",
" <td>0.351</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>8</td>\n",
" <td>183</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>23.3</td>\n",
" <td>0.672</td>\n",
" <td>32</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>89</td>\n",
" <td>66</td>\n",
" <td>23</td>\n",
" <td>94</td>\n",
" <td>28.1</td>\n",
" <td>0.167</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>137</td>\n",
" <td>40</td>\n",
" <td>35</td>\n",
" <td>168</td>\n",
" <td>43.1</td>\n",
" <td>2.288</td>\n",
" <td>33</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>763</th>\n",
" <td>10</td>\n",
" <td>101</td>\n",
" <td>76</td>\n",
" <td>48</td>\n",
" <td>180</td>\n",
" <td>32.9</td>\n",
" <td>0.171</td>\n",
" <td>63</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>764</th>\n",
" <td>2</td>\n",
" <td>122</td>\n",
" <td>70</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" <td>36.8</td>\n",
" <td>0.340</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>765</th>\n",
" <td>5</td>\n",
" <td>121</td>\n",
" <td>72</td>\n",
" <td>23</td>\n",
" <td>112</td>\n",
" <td>26.2</td>\n",
" <td>0.245</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>766</th>\n",
" <td>1</td>\n",
" <td>126</td>\n",
" <td>60</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.1</td>\n",
" <td>0.349</td>\n",
" <td>47</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>767</th>\n",
" <td>1</td>\n",
" <td>93</td>\n",
" <td>70</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>30.4</td>\n",
" <td>0.315</td>\n",
" <td>23</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>768 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 148 72 35 0 33.6 \n",
"1 1 85 66 29 0 26.6 \n",
"2 8 183 64 0 0 23.3 \n",
"3 1 89 66 23 94 28.1 \n",
"4 0 137 40 35 168 43.1 \n",
".. ... ... ... ... ... ... \n",
"763 10 101 76 48 180 32.9 \n",
"764 2 122 70 27 0 36.8 \n",
"765 5 121 72 23 112 26.2 \n",
"766 1 126 60 0 0 30.1 \n",
"767 1 93 70 31 0 30.4 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"0 0.627 50 1 \n",
"1 0.351 31 0 \n",
"2 0.672 32 1 \n",
"3 0.167 21 0 \n",
"4 2.288 33 1 \n",
".. ... ... ... \n",
"763 0.171 63 0 \n",
"764 0.340 27 0 \n",
"765 0.245 30 0 \n",
"766 0.349 47 1 \n",
"767 0.315 23 0 \n",
"\n",
"[768 rows x 9 columns]"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\".//scv//diabetes.csv\")\n",
"print(df.columns)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование выборок"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>1</td>\n",
" <td>105</td>\n",
" <td>58</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.3</td>\n",
" <td>0.187</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>4</td>\n",
" <td>146</td>\n",
" <td>85</td>\n",
" <td>27</td>\n",
" <td>100</td>\n",
" <td>28.9</td>\n",
" <td>0.189</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>3</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.174</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>5</td>\n",
" <td>88</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" <td>27.6</td>\n",
" <td>0.258</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>2</td>\n",
" <td>120</td>\n",
" <td>54</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>26.8</td>\n",
" <td>0.455</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>0</td>\n",
" <td>124</td>\n",
" <td>70</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.254</td>\n",
" <td>36</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>0</td>\n",
" <td>95</td>\n",
" <td>85</td>\n",
" <td>25</td>\n",
" <td>36</td>\n",
" <td>37.4</td>\n",
" <td>0.247</td>\n",
" <td>24</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>1</td>\n",
" <td>97</td>\n",
" <td>66</td>\n",
" <td>15</td>\n",
" <td>140</td>\n",
" <td>23.2</td>\n",
" <td>0.487</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>1</td>\n",
" <td>117</td>\n",
" <td>60</td>\n",
" <td>23</td>\n",
" <td>106</td>\n",
" <td>33.8</td>\n",
" <td>0.466</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>3</td>\n",
" <td>107</td>\n",
" <td>62</td>\n",
" <td>13</td>\n",
" <td>48</td>\n",
" <td>22.9</td>\n",
" <td>0.678</td>\n",
" <td>23</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"196 1 105 58 0 0 24.3 \n",
"69 4 146 85 27 100 28.9 \n",
"494 3 80 0 0 0 0.0 \n",
"463 5 88 78 30 0 27.6 \n",
"653 2 120 54 0 0 26.8 \n",
".. ... ... ... ... ... ... \n",
"322 0 124 70 20 0 27.4 \n",
"109 0 95 85 25 36 37.4 \n",
"27 1 97 66 15 140 23.2 \n",
"651 1 117 60 23 106 33.8 \n",
"197 3 107 62 13 48 22.9 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"196 0.187 21 0 \n",
"69 0.189 27 0 \n",
"494 0.174 22 0 \n",
"463 0.258 37 0 \n",
"653 0.455 27 0 \n",
".. ... ... ... \n",
"322 0.254 36 1 \n",
"109 0.247 24 1 \n",
"27 0.487 22 0 \n",
"651 0.466 27 0 \n",
"197 0.678 23 1 \n",
"\n",
"[614 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"196 0\n",
"69 0\n",
"494 0\n",
"463 0\n",
"653 0\n",
".. ...\n",
"322 1\n",
"109 1\n",
"27 0\n",
"651 0\n",
"197 1\n",
"\n",
"[614 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>9</td>\n",
" <td>154</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>30.9</td>\n",
" <td>0.164</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>379</th>\n",
" <td>0</td>\n",
" <td>93</td>\n",
" <td>100</td>\n",
" <td>39</td>\n",
" <td>72</td>\n",
" <td>43.4</td>\n",
" <td>1.021</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>640</th>\n",
" <td>0</td>\n",
" <td>102</td>\n",
" <td>86</td>\n",
" <td>17</td>\n",
" <td>105</td>\n",
" <td>29.3</td>\n",
" <td>0.695</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>11</td>\n",
" <td>127</td>\n",
" <td>106</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>39.0</td>\n",
" <td>0.190</td>\n",
" <td>51</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>304</th>\n",
" <td>3</td>\n",
" <td>150</td>\n",
" <td>76</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.0</td>\n",
" <td>0.207</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203</th>\n",
" <td>2</td>\n",
" <td>99</td>\n",
" <td>70</td>\n",
" <td>16</td>\n",
" <td>44</td>\n",
" <td>20.4</td>\n",
" <td>0.235</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>605</th>\n",
" <td>1</td>\n",
" <td>124</td>\n",
" <td>60</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.8</td>\n",
" <td>0.514</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>561</th>\n",
" <td>0</td>\n",
" <td>198</td>\n",
" <td>66</td>\n",
" <td>32</td>\n",
" <td>274</td>\n",
" <td>41.3</td>\n",
" <td>0.502</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>280</th>\n",
" <td>0</td>\n",
" <td>146</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>37.9</td>\n",
" <td>0.334</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103</th>\n",
" <td>1</td>\n",
" <td>81</td>\n",
" <td>72</td>\n",
" <td>18</td>\n",
" <td>40</td>\n",
" <td>26.6</td>\n",
" <td>0.283</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"669 9 154 78 30 100 30.9 \n",
"379 0 93 100 39 72 43.4 \n",
"640 0 102 86 17 105 29.3 \n",
"658 11 127 106 0 0 39.0 \n",
"304 3 150 76 0 0 21.0 \n",
".. ... ... ... ... ... ... \n",
"203 2 99 70 16 44 20.4 \n",
"605 1 124 60 32 0 35.8 \n",
"561 0 198 66 32 274 41.3 \n",
"280 0 146 70 0 0 37.9 \n",
"103 1 81 72 18 40 26.6 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"669 0.164 45 0 \n",
"379 1.021 35 0 \n",
"640 0.695 27 0 \n",
"658 0.190 51 0 \n",
"304 0.207 37 0 \n",
".. ... ... ... \n",
"203 0.235 27 0 \n",
"605 0.514 21 0 \n",
"561 0.502 28 1 \n",
"280 0.334 28 1 \n",
"103 0.283 24 0 \n",
"\n",
"[154 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>379</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>640</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>304</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>605</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>561</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>280</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"669 0\n",
"379 0\n",
"640 0\n",
"658 0\n",
"304 0\n",
".. ...\n",
"203 0\n",
"605 0\n",
"561 1\n",
"280 1\n",
"103 0\n",
"\n",
"[154 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"Outcome\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=9\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Классификация данных"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"from transformers import DiabetFeatures\n",
"\n",
"\n",
"columns_to_drop = [\"Glucose\", \"Age\", \"BloodPressure\", \"Outcome\", \"DiabetesPedigreeFunction\"]\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Проверка работы конвеера"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>-0.838489</td>\n",
" <td>-1.297466</td>\n",
" <td>-0.688684</td>\n",
" <td>-0.946400</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>0.072181</td>\n",
" <td>0.395520</td>\n",
" <td>0.180416</td>\n",
" <td>-0.377190</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>-0.231376</td>\n",
" <td>-1.297466</td>\n",
" <td>-0.688684</td>\n",
" <td>-3.953317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>0.375738</td>\n",
" <td>0.583630</td>\n",
" <td>-0.688684</td>\n",
" <td>-0.538054</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>-0.534932</td>\n",
" <td>-1.297466</td>\n",
" <td>-0.688684</td>\n",
" <td>-0.637047</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>-1.142046</td>\n",
" <td>-0.043402</td>\n",
" <td>-0.688684</td>\n",
" <td>-0.562802</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>-1.142046</td>\n",
" <td>0.270114</td>\n",
" <td>-0.375808</td>\n",
" <td>0.674613</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>-0.838489</td>\n",
" <td>-0.356918</td>\n",
" <td>0.528056</td>\n",
" <td>-1.082516</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>-0.838489</td>\n",
" <td>0.144708</td>\n",
" <td>0.232562</td>\n",
" <td>0.229143</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>-0.231376</td>\n",
" <td>-0.482325</td>\n",
" <td>-0.271516</td>\n",
" <td>-1.119638</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies SkinThickness Insulin BMI\n",
"196 -0.838489 -1.297466 -0.688684 -0.946400\n",
"69 0.072181 0.395520 0.180416 -0.377190\n",
"494 -0.231376 -1.297466 -0.688684 -3.953317\n",
"463 0.375738 0.583630 -0.688684 -0.538054\n",
"653 -0.534932 -1.297466 -0.688684 -0.637047\n",
".. ... ... ... ...\n",
"322 -1.142046 -0.043402 -0.688684 -0.562802\n",
"109 -1.142046 0.270114 -0.375808 0.674613\n",
"27 -0.838489 -0.356918 0.528056 -1.082516\n",
"651 -0.838489 0.144708 0.232562 0.229143\n",
"197 -0.231376 -0.482325 -0.271516 -1.119638\n",
"\n",
"[614 rows x 4 columns]"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование набора моделей для классификации"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=9\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=9,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сводная таблица оценок качества для использованных моделей классификации\n",
"\n",
"Матрица неточностей"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAQ9CAYAAACMbQYZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVhUZfsH8O+wDIMswyIwoIAg7op7imsSipZb8maavrmbC65ZaqWCu5Zp7mkGmppppbmkpqTm/rrnFm4oKIuGAoKyzZzfH/ycnGBwBgZmzvD9XNe5cp7zzJl7SM/NfZ7nPEciCIIAIiIiIiIiAgBYGDsAIiIiIiIiU8IiiYiIiIiI6CUskoiIiIiIiF7CIomIiIiIiOglLJKIiIiIiIhewiKJiIiIiIjoJSySiIiIiIiIXsIiiYiIiIiI6CUskoiIiIiIiF7CIomMLjo6GhKJBHfv3i2T49+9excSiQTR0dEGOd7hw4chkUhw+PBhgxyPiIjInEREREAikejUVyKRICIiomwDIioBFklEWqxcudJghRURERERiYeVsQMgKmu+vr54/vw5rK2t9XrfypUrUblyZQwcOFCjvV27dnj+/DmkUqkBoyQiIjIPn332GaZMmWLsMIhKhUUSmT2JRAKZTGaw41lYWBj0eEREROYiKysLdnZ2sLLir5gkbpxuRyZp5cqVqFevHmxsbODl5YXRo0cjLS2tUL8VK1bA398ftra2eO2113D06FG8/vrreP3119V9ironKTk5GYMGDULVqlVhY2MDT09P9OjRQ31fVLVq1XD16lUcOXIEEokEEolEfUxt9ySdPn0ab775JpydnWFnZ4fAwEB89dVXhv3BEBERmYgX9x5du3YN7733HpydndGmTZsi70nKycnBhAkT4ObmBgcHB3Tv3h33798v8riHDx9Gs2bNIJPJUL16dXz99dda73PauHEjmjZtCltbW7i4uKBPnz5ISEgok+9LFQvLfDI5ERERiIyMREhICEaOHInY2FisWrUKZ86cwfHjx9XT5latWoXw8HC0bdsWEyZMwN27d9GzZ084OzujatWqxX5GWFgYrl69ijFjxqBatWp4+PAhDhw4gPj4eFSrVg1LlizBmDFjYG9vj08//RQA4OHhofV4Bw4cQNeuXeHp6Ylx48ZBoVDg+vXr2L17N8aNG2e4Hw4REZGJeeedd1CjRg3MnTsXgiDg4cOHhfoMHToUGzduxHvvvYdWrVrh999/x1tvvVWo34ULF9C5c2d4enoiMjISSqUSM2fOhJubW6G+c+bMwbRp09C7d28MHToUjx49wrJly9CuXTtcuHABTk5OZfF1qaIQiIwsKipKACDExcUJDx8+FKRSqdCpUydBqVSq+yxfvlwAIHz77beCIAhCTk6O4OrqKjRv3lzIy8tT94uOjhYACO3bt1e3xcXFCQCEqKgoQRAE4cmTJwIA4fPPPy82rnr16mkc54VDhw4JAIRDhw4JgiAI+fn5gp+fn+Dr6ys8efJEo69KpdL9B0FERCQiM2bMEAAIffv2LbL9hYsXLwoAhFGjRmn0e++99wQAwowZM9Rt3bp1EypVqiQ8ePBA3Xbz5k3ByspK45h3794VLC0thTlz5mgc8/Lly4KVlVWhdiJ9cbodmZSDBw8iNzcX48ePh4XFP389hw0bBkdHR+zZswcAcPbsWaSmpmLYsGEa85779esHZ2fnYj/D1tYWUqkUhw8fxpMnT0od84ULFxAXF4fx48cXumql6xKoREREYjVixIhi9//6668AgLFjx2q0jx8/XuO1UqnEwYMH0bNnT3h5eanbAwIC0KVLF42+P//8M1QqFXr37o2///5bvSkUCtSoUQOHDh0qxTci4nQ7MjH37t0DANSqVUujXSqVwt/fX73/xX8DAgI0+llZWaFatWrFfoaNjQ0WLFiADz/8EB4eHmjZsiW6du2K999/HwqFQu+Yb9++DQCoX7++3u8lIiISOz8/v2L337t3DxYWFqhevbpG+79z/cOHD/H8+fNCuR0onO9v3rwJQRBQo0aNIj9T3xVtif6NRRJVSOPHj0e3bt2wY8cO7N+/H9OmTcO8efPw+++/o3HjxsYOj4iISDRsbW3L/TNVKhUkEgn27t0LS0vLQvvt7e3LPSYyL5xuRybF19cXABAbG6vRnpubi7i4OPX+F/+9deuWRr/8/Hz1CnWvUr16dXz44Yf47bffcOXKFeTm5mLRokXq/bpOlXtxZezKlSs69SciIqpIfH19oVKp1DMvXvh3rnd3d4dMJiuU24HC+b569eoQBAF+fn4ICQkptLVs2dLwX4QqFBZJZFJCQkIglUqxdOlSCIKgbl+3bh3S09PVK+E0a9YMrq6uWLt2LfLz89X9Nm3a9Mr7jJ49e4bs7GyNturVq8PBwQE5OTnqNjs7uyKXHf+3Jk2awM/PD0uWLCnU/+XvQEREVBG9uJ9o6dKlGu1LlizReG1paYmQkBDs2LEDiYmJ6vZbt25h7969Gn179eoFS0tLREZGFsq1giAgNTXVgN+AKiJOtyOT4ubmhqlTpyIyMhKdO3dG9+7dERsbi5UrV6J58+bo378/gIJ7lCIiIjBmzBgEBwejd+/euHv3LqKjo1G9evViR4Fu3LiBN954A71790bdunVhZWWF7du3IyUlBX369FH3a9q0KVatWoXZs2cjICAA7u7uCA4OLnQ8CwsLrFq1Ct26dUOjRo0waNAgeHp64q+//sLVq1exf/9+w/+giIiIRKJRo0bo27cvVq5cifT0dLRq1QoxMTFFjhhFRETgt99+Q+vWrTFy5EgolUosX74c9evXx8WLF9X9qlevjtmzZ2Pq1KnqR4A4ODggLi4O27dvx/DhwzFp0qRy/JZkblgkkcmJiIiAm5sbli9fjgkTJsDFxQXDhw/H3LlzNW7EDA8PhyAIWLRoESZNmoSGDRti586dGDt2LGQymdbje3t7o2/fvoiJicF3330HKysr1K5dG1u3bkVYWJi63/Tp03Hv3j0sXLgQT58+Rfv27YsskgAgNDQUhw4dQmRkJBYtWgSVSoXq1atj2LBhhvvBEBERidS3334LNzc3bNq0CTt27EBwcDD27NkDb29vjX5NmzbF3r17MWnSJEybNg3e3t6YOXMmrl+/jr/++kuj75QpU1CzZk0sXrwYkZGRAApyfKdOndC9e/dy+25kniQC5wORGVGpVHBzc0OvXr2wdu1aY4dDREREBtCzZ09cvXoVN2/eNHYoVEHwniQSrezs7ELzkDds2IDHjx/j9ddfN05QREREVCrPnz/XeH3z5k38+uuvzO1UrjiSRKJ1+PBhTJgwAe+88w5cXV1x/vx5rFu3DnXq1MG5c+cglUqNHSIRERHpydPTEwMHDlQ/H3HVqlXIycnBhQsXtD4XicjQeE8SiVa1atXg7e2NpUuX4vHjx3BxccH777+P+fPns0AiIiISqc6dO+P7779HcnIybGxsEBQUhLlz57JAonLFkSQiIiIiIqKX8J4kIiIiIiKil7BIIiIiIiIiegnvSSpnKpUKiYmJcHBwKPaBp0TmSBAEPH36FF5eXrCwMPw1muzsbOTm5hbbRyqVFvscrZcplUpERERg48aNSE5OhpeXFwYOHIjPPvtM/e9XEATMmDEDa9euRVpaGlq3bo1Vq1Zx7jyRiDA3U0VXlvnZ0Lm5vLBIKmeJiYmFHpxGVNEkJCSgatWqBj1mdnY2/HztkfxQWWw/hUKBuLg4nU7GCxYswKpVq7B+/XrUq1cPZ8+exaBBgyCXyzF27FgAwMKFC7F06VKsX78efn5+mDZtGkJDQ3Ht2jWTO+ETUdGYm4kKGDo/l0VuLi9cuKGcpaenw8nJCffOV4OjPWc7GsPbNRsYO4QKKx95OIZfkZaWBrlcbtBjZ2RkQC6X49ZZbzg6FP1vK+OpCgHNEpCeng5HR8dXHrNr167w8PDAunXr1G1hYWGwtbXFxo0bIQgCvLy88OGHH2LSpEkACv6Ne3h4IDo6Gn369DHMlyOiMsXcbHyvfT3U2CFUaMqcbNxZPtPg+bkscnN54UhSOXsxjO9ob6H
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Healthy\", \"Sick\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_19144_row0_col0 {\n",
" background-color: #1f9e89;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row0_col1, #T_19144_row3_col0, #T_19144_row3_col2, #T_19144_row7_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_19144_row0_col2, #T_19144_row4_col3, #T_19144_row5_col3, #T_19144_row7_col0, #T_19144_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row0_col3 {\n",
" background-color: #25848e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row0_col4 {\n",
" background-color: #8707a6;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row0_col5, #T_19144_row1_col7, #T_19144_row3_col4, #T_19144_row3_col6 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row0_col6, #T_19144_row5_col7, #T_19144_row7_col4, #T_19144_row7_col5 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row0_col7 {\n",
" background-color: #7a02a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row1_col0 {\n",
" background-color: #21918c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row1_col1 {\n",
" background-color: #4ac16d;\n",
" color: #000000;\n",
"}\n",
"#T_19144_row1_col2 {\n",
" background-color: #27ad81;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row1_col3 {\n",
" background-color: #69cd5b;\n",
" color: #000000;\n",
"}\n",
"#T_19144_row1_col4 {\n",
" background-color: #7d03a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row1_col5 {\n",
" background-color: #d6556d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row1_col6, #T_19144_row7_col6 {\n",
" background-color: #7801a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row2_col0 {\n",
" background-color: #2db27d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row2_col1 {\n",
" background-color: #48c16e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row2_col2, #T_19144_row6_col3 {\n",
" background-color: #1e9c89;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row2_col3 {\n",
" background-color: #1f988b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row2_col4 {\n",
" background-color: #9d189d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row2_col5 {\n",
" background-color: #cc4778;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row2_col6 {\n",
" background-color: #8004a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row2_col7 {\n",
" background-color: #920fa3;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row3_col1 {\n",
" background-color: #3aba76;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row3_col3 {\n",
" background-color: #20a386;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row3_col5 {\n",
" background-color: #c8437b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row3_col7, #T_19144_row6_col5 {\n",
" background-color: #a11b9b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row4_col0 {\n",
" background-color: #23a983;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row4_col1 {\n",
" background-color: #32b67a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row4_col2 {\n",
" background-color: #26828e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row4_col4 {\n",
" background-color: #8e0ca4;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row4_col5 {\n",
" background-color: #c03a83;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row4_col6 {\n",
" background-color: #5601a4;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row4_col7 {\n",
" background-color: #5102a3;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row5_col0 {\n",
" background-color: #95d840;\n",
" color: #000000;\n",
"}\n",
"#T_19144_row5_col1 {\n",
" background-color: #29af7f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row5_col2 {\n",
" background-color: #5cc863;\n",
" color: #000000;\n",
"}\n",
"#T_19144_row5_col4 {\n",
" background-color: #ca457a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row5_col5 {\n",
" background-color: #bc3587;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row5_col6 {\n",
" background-color: #c6417d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row6_col0 {\n",
" background-color: #3bbb75;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row6_col1 {\n",
" background-color: #1f958b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row6_col2 {\n",
" background-color: #52c569;\n",
" color: #000000;\n",
"}\n",
"#T_19144_row6_col4 {\n",
" background-color: #b32c8e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row6_col6 {\n",
" background-color: #ad2793;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row6_col7 {\n",
" background-color: #7401a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_19144_row7_col2 {\n",
" background-color: #65cb5e;\n",
" color: #000000;\n",
"}\n",
"#T_19144_row7_col7 {\n",
" background-color: #b12a90;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_19144\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_19144_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_19144_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_19144_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_19144_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_19144_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_19144_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_19144_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_19144_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_19144_level0_row0\" class=\"row_heading level0 row0\" >naive_bayes</th>\n",
" <td id=\"T_19144_row0_col0\" class=\"data row0 col0\" >0.564516</td>\n",
" <td id=\"T_19144_row0_col1\" class=\"data row0 col1\" >0.628571</td>\n",
" <td id=\"T_19144_row0_col2\" class=\"data row0 col2\" >0.327103</td>\n",
" <td id=\"T_19144_row0_col3\" class=\"data row0 col3\" >0.407407</td>\n",
" <td id=\"T_19144_row0_col4\" class=\"data row0 col4\" >0.677524</td>\n",
" <td id=\"T_19144_row0_col5\" class=\"data row0 col5\" >0.707792</td>\n",
" <td id=\"T_19144_row0_col6\" class=\"data row0 col6\" >0.414201</td>\n",
" <td id=\"T_19144_row0_col7\" class=\"data row0 col7\" >0.494382</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_19144_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_19144_row1_col0\" class=\"data row1 col0\" >0.494382</td>\n",
" <td id=\"T_19144_row1_col1\" class=\"data row1 col1\" >0.552632</td>\n",
" <td id=\"T_19144_row1_col2\" class=\"data row1 col2\" >0.616822</td>\n",
" <td id=\"T_19144_row1_col3\" class=\"data row1 col3\" >0.777778</td>\n",
" <td id=\"T_19144_row1_col4\" class=\"data row1 col4\" >0.646580</td>\n",
" <td id=\"T_19144_row1_col5\" class=\"data row1 col5\" >0.701299</td>\n",
" <td id=\"T_19144_row1_col6\" class=\"data row1 col6\" >0.548857</td>\n",
" <td id=\"T_19144_row1_col7\" class=\"data row1 col7\" >0.646154</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_19144_level0_row2\" class=\"row_heading level0 row2\" >knn</th>\n",
" <td id=\"T_19144_row2_col0\" class=\"data row2 col0\" >0.670807</td>\n",
" <td id=\"T_19144_row2_col1\" class=\"data row2 col1\" >0.551020</td>\n",
" <td id=\"T_19144_row2_col2\" class=\"data row2 col2\" >0.504673</td>\n",
" <td id=\"T_19144_row2_col3\" class=\"data row2 col3\" >0.500000</td>\n",
" <td id=\"T_19144_row2_col4\" class=\"data row2 col4\" >0.741042</td>\n",
" <td id=\"T_19144_row2_col5\" class=\"data row2 col5\" >0.681818</td>\n",
" <td id=\"T_19144_row2_col6\" class=\"data row2 col6\" >0.576000</td>\n",
" <td id=\"T_19144_row2_col7\" class=\"data row2 col7\" >0.524272</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_19144_level0_row3\" class=\"row_heading level0 row3\" >random_forest</th>\n",
" <td id=\"T_19144_row3_col0\" class=\"data row3 col0\" >0.955157</td>\n",
" <td id=\"T_19144_row3_col1\" class=\"data row3 col1\" >0.535714</td>\n",
" <td id=\"T_19144_row3_col2\" class=\"data row3 col2\" >0.995327</td>\n",
" <td id=\"T_19144_row3_col3\" class=\"data row3 col3\" >0.555556</td>\n",
" <td id=\"T_19144_row3_col4\" class=\"data row3 col4\" >0.982085</td>\n",
" <td id=\"T_19144_row3_col5\" class=\"data row3 col5\" >0.675325</td>\n",
" <td id=\"T_19144_row3_col6\" class=\"data row3 col6\" >0.974828</td>\n",
" <td id=\"T_19144_row3_col7\" class=\"data row3 col7\" >0.545455</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_19144_level0_row4\" class=\"row_heading level0 row4\" >logistic</th>\n",
" <td id=\"T_19144_row4_col0\" class=\"data row4 col0\" >0.618644</td>\n",
" <td id=\"T_19144_row4_col1\" class=\"data row4 col1\" >0.525000</td>\n",
" <td id=\"T_19144_row4_col2\" class=\"data row4 col2\" >0.341121</td>\n",
" <td id=\"T_19144_row4_col3\" class=\"data row4 col3\" >0.388889</td>\n",
" <td id=\"T_19144_row4_col4\" class=\"data row4 col4\" >0.697068</td>\n",
" <td id=\"T_19144_row4_col5\" class=\"data row4 col5\" >0.662338</td>\n",
" <td id=\"T_19144_row4_col6\" class=\"data row4 col6\" >0.439759</td>\n",
" <td id=\"T_19144_row4_col7\" class=\"data row4 col7\" >0.446809</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_19144_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_19144_row5_col0\" class=\"data row5 col0\" >0.920213</td>\n",
" <td id=\"T_19144_row5_col1\" class=\"data row5 col1\" >0.512195</td>\n",
" <td id=\"T_19144_row5_col2\" class=\"data row5 col2\" >0.808411</td>\n",
" <td id=\"T_19144_row5_col3\" class=\"data row5 col3\" >0.388889</td>\n",
" <td id=\"T_19144_row5_col4\" class=\"data row5 col4\" >0.908795</td>\n",
" <td id=\"T_19144_row5_col5\" class=\"data row5 col5\" >0.655844</td>\n",
" <td id=\"T_19144_row5_col6\" class=\"data row5 col6\" >0.860697</td>\n",
" <td id=\"T_19144_row5_col7\" class=\"data row5 col7\" >0.442105</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_19144_level0_row6\" class=\"row_heading level0 row6\" >decision_tree</th>\n",
" <td id=\"T_19144_row6_col0\" class=\"data row6 col0\" >0.718615</td>\n",
" <td id=\"T_19144_row6_col1\" class=\"data row6 col1\" >0.459016</td>\n",
" <td id=\"T_19144_row6_col2\" class=\"data row6 col2\" >0.775701</td>\n",
" <td id=\"T_19144_row6_col3\" class=\"data row6 col3\" >0.518519</td>\n",
" <td id=\"T_19144_row6_col4\" class=\"data row6 col4\" >0.815961</td>\n",
" <td id=\"T_19144_row6_col5\" class=\"data row6 col5\" >0.616883</td>\n",
" <td id=\"T_19144_row6_col6\" class=\"data row6 col6\" >0.746067</td>\n",
" <td id=\"T_19144_row6_col7\" class=\"data row6 col7\" >0.486957</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_19144_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_19144_row7_col0\" class=\"data row7 col0\" >0.409195</td>\n",
" <td id=\"T_19144_row7_col1\" class=\"data row7 col1\" >0.417391</td>\n",
" <td id=\"T_19144_row7_col2\" class=\"data row7 col2\" >0.831776</td>\n",
" <td id=\"T_19144_row7_col3\" class=\"data row7 col3\" >0.888889</td>\n",
" <td id=\"T_19144_row7_col4\" class=\"data row7 col4\" >0.522801</td>\n",
" <td id=\"T_19144_row7_col5\" class=\"data row7 col5\" >0.525974</td>\n",
" <td id=\"T_19144_row7_col6\" class=\"data row7 col6\" >0.548536</td>\n",
" <td id=\"T_19144_row7_col7\" class=\"data row7 col7\" >0.568047</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x175bc15bbc0>"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_17a4c_row0_col0 {\n",
" background-color: #9dd93b;\n",
" color: #000000;\n",
"}\n",
"#T_17a4c_row0_col1, #T_17a4c_row2_col0 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_17a4c_row0_col2, #T_17a4c_row0_col3, #T_17a4c_row0_col4, #T_17a4c_row1_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row1_col0 {\n",
" background-color: #63cb5f;\n",
" color: #000000;\n",
"}\n",
"#T_17a4c_row1_col1 {\n",
" background-color: #26828e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row1_col3 {\n",
" background-color: #6a00a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row1_col4 {\n",
" background-color: #6600a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row2_col1 {\n",
" background-color: #1e9b8a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row2_col2 {\n",
" background-color: #d24f71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row2_col3, #T_17a4c_row2_col4 {\n",
" background-color: #aa2395;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row3_col0 {\n",
" background-color: #7fd34e;\n",
" color: #000000;\n",
"}\n",
"#T_17a4c_row3_col1 {\n",
" background-color: #25ab82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row3_col2 {\n",
" background-color: #cc4977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row3_col3 {\n",
" background-color: #a01a9c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row3_col4 {\n",
" background-color: #9814a0;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row4_col0, #T_17a4c_row6_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row4_col1 {\n",
" background-color: #46c06f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row4_col2 {\n",
" background-color: #c13b82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row4_col3, #T_17a4c_row7_col2, #T_17a4c_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row4_col4 {\n",
" background-color: #7801a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row5_col0 {\n",
" background-color: #75d054;\n",
" color: #000000;\n",
"}\n",
"#T_17a4c_row5_col1 {\n",
" background-color: #31b57b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row5_col2 {\n",
" background-color: #b6308b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row5_col3 {\n",
" background-color: #a51f99;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row5_col4 {\n",
" background-color: #9c179e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row6_col0 {\n",
" background-color: #5ac864;\n",
" color: #000000;\n",
"}\n",
"#T_17a4c_row6_col2 {\n",
" background-color: #b22b8f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row6_col3 {\n",
" background-color: #6300a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row6_col4 {\n",
" background-color: #5c01a6;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row7_col0 {\n",
" background-color: #2fb47c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row7_col1 {\n",
" background-color: #1f978b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_17a4c_row7_col3 {\n",
" background-color: #5502a4;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_17a4c\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_17a4c_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_17a4c_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_17a4c_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_17a4c_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_17a4c_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_17a4c_level0_row0\" class=\"row_heading level0 row0\" >ridge</th>\n",
" <td id=\"T_17a4c_row0_col0\" class=\"data row0 col0\" >0.701299</td>\n",
" <td id=\"T_17a4c_row0_col1\" class=\"data row0 col1\" >0.646154</td>\n",
" <td id=\"T_17a4c_row0_col2\" class=\"data row0 col2\" >0.767037</td>\n",
" <td id=\"T_17a4c_row0_col3\" class=\"data row0 col3\" >0.400271</td>\n",
" <td id=\"T_17a4c_row0_col4\" class=\"data row0 col4\" >0.417827</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_17a4c_level0_row1\" class=\"row_heading level0 row1\" >logistic</th>\n",
" <td id=\"T_17a4c_row1_col0\" class=\"data row1 col0\" >0.662338</td>\n",
" <td id=\"T_17a4c_row1_col1\" class=\"data row1 col1\" >0.446809</td>\n",
" <td id=\"T_17a4c_row1_col2\" class=\"data row1 col2\" >0.766296</td>\n",
" <td id=\"T_17a4c_row1_col3\" class=\"data row1 col3\" >0.211501</td>\n",
" <td id=\"T_17a4c_row1_col4\" class=\"data row1 col4\" >0.216434</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_17a4c_level0_row2\" class=\"row_heading level0 row2\" >naive_bayes</th>\n",
" <td id=\"T_17a4c_row2_col0\" class=\"data row2 col0\" >0.707792</td>\n",
" <td id=\"T_17a4c_row2_col1\" class=\"data row2 col1\" >0.494382</td>\n",
" <td id=\"T_17a4c_row2_col2\" class=\"data row2 col2\" >0.753704</td>\n",
" <td id=\"T_17a4c_row2_col3\" class=\"data row2 col3\" >0.301834</td>\n",
" <td id=\"T_17a4c_row2_col4\" class=\"data row2 col4\" >0.315869</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_17a4c_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_17a4c_row3_col0\" class=\"data row3 col0\" >0.681818</td>\n",
" <td id=\"T_17a4c_row3_col1\" class=\"data row3 col1\" >0.524272</td>\n",
" <td id=\"T_17a4c_row3_col2\" class=\"data row3 col2\" >0.745556</td>\n",
" <td id=\"T_17a4c_row3_col3\" class=\"data row3 col3\" >0.286093</td>\n",
" <td id=\"T_17a4c_row3_col4\" class=\"data row3 col4\" >0.286855</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_17a4c_level0_row4\" class=\"row_heading level0 row4\" >mlp</th>\n",
" <td id=\"T_17a4c_row4_col0\" class=\"data row4 col0\" >0.525974</td>\n",
" <td id=\"T_17a4c_row4_col1\" class=\"data row4 col1\" >0.568047</td>\n",
" <td id=\"T_17a4c_row4_col2\" class=\"data row4 col2\" >0.729074</td>\n",
" <td id=\"T_17a4c_row4_col3\" class=\"data row4 col3\" >0.173747</td>\n",
" <td id=\"T_17a4c_row4_col4\" class=\"data row4 col4\" >0.240181</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_17a4c_level0_row5\" class=\"row_heading level0 row5\" >random_forest</th>\n",
" <td id=\"T_17a4c_row5_col0\" class=\"data row5 col0\" >0.675325</td>\n",
" <td id=\"T_17a4c_row5_col1\" class=\"data row5 col1\" >0.545455</td>\n",
" <td id=\"T_17a4c_row5_col2\" class=\"data row5 col2\" >0.715093</td>\n",
" <td id=\"T_17a4c_row5_col3\" class=\"data row5 col3\" >0.293059</td>\n",
" <td id=\"T_17a4c_row5_col4\" class=\"data row5 col4\" >0.293176</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_17a4c_level0_row6\" class=\"row_heading level0 row6\" >gradient_boosting</th>\n",
" <td id=\"T_17a4c_row6_col0\" class=\"data row6 col0\" >0.655844</td>\n",
" <td id=\"T_17a4c_row6_col1\" class=\"data row6 col1\" >0.442105</td>\n",
" <td id=\"T_17a4c_row6_col2\" class=\"data row6 col2\" >0.709630</td>\n",
" <td id=\"T_17a4c_row6_col3\" class=\"data row6 col3\" >0.199961</td>\n",
" <td id=\"T_17a4c_row6_col4\" class=\"data row6 col4\" >0.203926</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_17a4c_level0_row7\" class=\"row_heading level0 row7\" >decision_tree</th>\n",
" <td id=\"T_17a4c_row7_col0\" class=\"data row7 col0\" >0.616883</td>\n",
" <td id=\"T_17a4c_row7_col1\" class=\"data row7 col1\" >0.486957</td>\n",
" <td id=\"T_17a4c_row7_col2\" class=\"data row7 col2\" >0.612870</td>\n",
" <td id=\"T_17a4c_row7_col3\" class=\"data row7 col3\" >0.183061</td>\n",
" <td id=\"T_17a4c_row7_col4\" class=\"data row7 col4\" >0.183927</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x175bec2b980>"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'ridge'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 46'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Predicted</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>30</th>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>109</td>\n",
" <td>75</td>\n",
" <td>26</td>\n",
" <td>0</td>\n",
" <td>36.0</td>\n",
" <td>0.546</td>\n",
" <td>60</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>82</th>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>83</td>\n",
" <td>78</td>\n",
" <td>26</td>\n",
" <td>71</td>\n",
" <td>29.3</td>\n",
" <td>0.767</td>\n",
" <td>36</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>86</th>\n",
" <td>13</td>\n",
" <td>1</td>\n",
" <td>106</td>\n",
" <td>72</td>\n",
" <td>54</td>\n",
" <td>0</td>\n",
" <td>36.6</td>\n",
" <td>0.178</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>91</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>123</td>\n",
" <td>80</td>\n",
" <td>15</td>\n",
" <td>176</td>\n",
" <td>32.0</td>\n",
" <td>0.443</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>144</td>\n",
" <td>72</td>\n",
" <td>27</td>\n",
" <td>228</td>\n",
" <td>33.9</td>\n",
" <td>0.255</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>176</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>85</td>\n",
" <td>78</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>31.2</td>\n",
" <td>0.382</td>\n",
" <td>42</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>201</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>138</td>\n",
" <td>82</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40.1</td>\n",
" <td>0.236</td>\n",
" <td>28</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>204</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>103</td>\n",
" <td>72</td>\n",
" <td>32</td>\n",
" <td>190</td>\n",
" <td>37.7</td>\n",
" <td>0.324</td>\n",
" <td>55</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>223</th>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>142</td>\n",
" <td>60</td>\n",
" <td>33</td>\n",
" <td>190</td>\n",
" <td>28.8</td>\n",
" <td>0.687</td>\n",
" <td>61</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>228</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>197</td>\n",
" <td>70</td>\n",
" <td>39</td>\n",
" <td>744</td>\n",
" <td>36.7</td>\n",
" <td>2.329</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>233</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>122</td>\n",
" <td>68</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>35.0</td>\n",
" <td>0.394</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>266</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>138</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>36.3</td>\n",
" <td>0.933</td>\n",
" <td>25</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>274</th>\n",
" <td>13</td>\n",
" <td>1</td>\n",
" <td>106</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>34.2</td>\n",
" <td>0.251</td>\n",
" <td>52</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>280</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>146</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>37.9</td>\n",
" <td>0.334</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>282</th>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>133</td>\n",
" <td>88</td>\n",
" <td>15</td>\n",
" <td>155</td>\n",
" <td>32.4</td>\n",
" <td>0.262</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>302</th>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>77</td>\n",
" <td>82</td>\n",
" <td>41</td>\n",
" <td>42</td>\n",
" <td>35.8</td>\n",
" <td>0.156</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>309</th>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>124</td>\n",
" <td>68</td>\n",
" <td>28</td>\n",
" <td>205</td>\n",
" <td>32.9</td>\n",
" <td>0.875</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>335</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>165</td>\n",
" <td>76</td>\n",
" <td>43</td>\n",
" <td>255</td>\n",
" <td>47.9</td>\n",
" <td>0.259</td>\n",
" <td>26</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>358</th>\n",
" <td>12</td>\n",
" <td>1</td>\n",
" <td>88</td>\n",
" <td>74</td>\n",
" <td>40</td>\n",
" <td>54</td>\n",
" <td>35.3</td>\n",
" <td>0.378</td>\n",
" <td>48</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>364</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>147</td>\n",
" <td>74</td>\n",
" <td>25</td>\n",
" <td>293</td>\n",
" <td>34.9</td>\n",
" <td>0.385</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>379</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>93</td>\n",
" <td>100</td>\n",
" <td>39</td>\n",
" <td>72</td>\n",
" <td>43.4</td>\n",
" <td>1.021</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>397</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>131</td>\n",
" <td>66</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" <td>34.3</td>\n",
" <td>0.196</td>\n",
" <td>22</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>405</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>123</td>\n",
" <td>48</td>\n",
" <td>32</td>\n",
" <td>165</td>\n",
" <td>42.1</td>\n",
" <td>0.520</td>\n",
" <td>26</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>406</th>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>115</td>\n",
" <td>72</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>28.9</td>\n",
" <td>0.376</td>\n",
" <td>46</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>442</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>117</td>\n",
" <td>64</td>\n",
" <td>27</td>\n",
" <td>120</td>\n",
" <td>33.2</td>\n",
" <td>0.230</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>486</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>139</td>\n",
" <td>62</td>\n",
" <td>41</td>\n",
" <td>480</td>\n",
" <td>40.7</td>\n",
" <td>0.536</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>515</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>163</td>\n",
" <td>70</td>\n",
" <td>18</td>\n",
" <td>105</td>\n",
" <td>31.6</td>\n",
" <td>0.268</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>517</th>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>125</td>\n",
" <td>86</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>37.6</td>\n",
" <td>0.304</td>\n",
" <td>51</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>583</th>\n",
" <td>8</td>\n",
" <td>1</td>\n",
" <td>100</td>\n",
" <td>76</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>38.7</td>\n",
" <td>0.190</td>\n",
" <td>42</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>594</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>123</td>\n",
" <td>72</td>\n",
" <td>45</td>\n",
" <td>230</td>\n",
" <td>33.6</td>\n",
" <td>0.733</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>622</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>183</td>\n",
" <td>94</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40.8</td>\n",
" <td>1.461</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>630</th>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>114</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.732</td>\n",
" <td>34</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>634</th>\n",
" <td>10</td>\n",
" <td>1</td>\n",
" <td>92</td>\n",
" <td>62</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>25.9</td>\n",
" <td>0.167</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>646</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>167</td>\n",
" <td>74</td>\n",
" <td>17</td>\n",
" <td>144</td>\n",
" <td>23.4</td>\n",
" <td>0.447</td>\n",
" <td>33</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>11</td>\n",
" <td>1</td>\n",
" <td>127</td>\n",
" <td>106</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>39.0</td>\n",
" <td>0.190</td>\n",
" <td>51</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>154</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>30.9</td>\n",
" <td>0.164</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>674</th>\n",
" <td>8</td>\n",
" <td>1</td>\n",
" <td>91</td>\n",
" <td>82</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>35.6</td>\n",
" <td>0.587</td>\n",
" <td>68</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>676</th>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>156</td>\n",
" <td>86</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.8</td>\n",
" <td>0.230</td>\n",
" <td>53</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>682</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>95</td>\n",
" <td>64</td>\n",
" <td>39</td>\n",
" <td>105</td>\n",
" <td>44.6</td>\n",
" <td>0.366</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>699</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>118</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>44.5</td>\n",
" <td>0.904</td>\n",
" <td>26</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>702</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>168</td>\n",
" <td>88</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>35.0</td>\n",
" <td>0.905</td>\n",
" <td>52</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>723</th>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>117</td>\n",
" <td>86</td>\n",
" <td>30</td>\n",
" <td>105</td>\n",
" <td>39.1</td>\n",
" <td>0.251</td>\n",
" <td>42</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>725</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>112</td>\n",
" <td>78</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" <td>39.4</td>\n",
" <td>0.236</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>730</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>130</td>\n",
" <td>78</td>\n",
" <td>23</td>\n",
" <td>79</td>\n",
" <td>28.4</td>\n",
" <td>0.323</td>\n",
" <td>34</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>744</th>\n",
" <td>13</td>\n",
" <td>1</td>\n",
" <td>153</td>\n",
" <td>88</td>\n",
" <td>37</td>\n",
" <td>140</td>\n",
" <td>40.6</td>\n",
" <td>1.174</td>\n",
" <td>39</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>750</th>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>136</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>31.2</td>\n",
" <td>1.182</td>\n",
" <td>22</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Predicted Glucose BloodPressure SkinThickness Insulin \\\n",
"30 5 1 109 75 26 0 \n",
"82 7 1 83 78 26 71 \n",
"86 13 1 106 72 54 0 \n",
"91 4 1 123 80 15 176 \n",
"95 6 1 144 72 27 228 \n",
"176 6 1 85 78 0 0 \n",
"201 1 1 138 82 0 0 \n",
"204 6 1 103 72 32 190 \n",
"223 7 1 142 60 33 190 \n",
"228 4 1 197 70 39 744 \n",
"233 4 1 122 68 0 0 \n",
"266 0 0 138 0 0 0 \n",
"274 13 1 106 70 0 0 \n",
"280 0 0 146 70 0 0 \n",
"282 7 1 133 88 15 155 \n",
"302 5 1 77 82 41 42 \n",
"309 2 0 124 68 28 205 \n",
"335 0 1 165 76 43 255 \n",
"358 12 1 88 74 40 54 \n",
"364 4 1 147 74 25 293 \n",
"379 0 1 93 100 39 72 \n",
"397 0 0 131 66 40 0 \n",
"405 2 1 123 48 32 165 \n",
"406 4 0 115 72 0 0 \n",
"442 4 1 117 64 27 120 \n",
"486 1 1 139 62 41 480 \n",
"515 3 0 163 70 18 105 \n",
"517 7 1 125 86 0 0 \n",
"583 8 1 100 76 0 0 \n",
"594 6 1 123 72 45 230 \n",
"622 6 1 183 94 0 0 \n",
"630 7 0 114 64 0 0 \n",
"634 10 1 92 62 0 0 \n",
"646 1 0 167 74 17 144 \n",
"658 11 1 127 106 0 0 \n",
"669 9 1 154 78 30 100 \n",
"674 8 1 91 82 0 0 \n",
"676 9 0 156 86 0 0 \n",
"682 0 1 95 64 39 105 \n",
"699 4 1 118 70 0 0 \n",
"702 1 0 168 88 29 0 \n",
"723 5 1 117 86 30 105 \n",
"725 4 1 112 78 40 0 \n",
"730 3 0 130 78 23 79 \n",
"744 13 1 153 88 37 140 \n",
"750 4 0 136 70 0 0 \n",
"\n",
" BMI DiabetesPedigreeFunction Age Outcome \n",
"30 36.0 0.546 60 0 \n",
"82 29.3 0.767 36 0 \n",
"86 36.6 0.178 45 0 \n",
"91 32.0 0.443 34 0 \n",
"95 33.9 0.255 40 0 \n",
"176 31.2 0.382 42 0 \n",
"201 40.1 0.236 28 0 \n",
"204 37.7 0.324 55 0 \n",
"223 28.8 0.687 61 0 \n",
"228 36.7 2.329 31 0 \n",
"233 35.0 0.394 29 0 \n",
"266 36.3 0.933 25 1 \n",
"274 34.2 0.251 52 0 \n",
"280 37.9 0.334 28 1 \n",
"282 32.4 0.262 37 0 \n",
"302 35.8 0.156 35 0 \n",
"309 32.9 0.875 30 1 \n",
"335 47.9 0.259 26 0 \n",
"358 35.3 0.378 48 0 \n",
"364 34.9 0.385 30 0 \n",
"379 43.4 1.021 35 0 \n",
"397 34.3 0.196 22 1 \n",
"405 42.1 0.520 26 0 \n",
"406 28.9 0.376 46 1 \n",
"442 33.2 0.230 24 0 \n",
"486 40.7 0.536 21 0 \n",
"515 31.6 0.268 28 1 \n",
"517 37.6 0.304 51 0 \n",
"583 38.7 0.190 42 0 \n",
"594 33.6 0.733 34 0 \n",
"622 40.8 1.461 45 0 \n",
"630 27.4 0.732 34 1 \n",
"634 25.9 0.167 31 0 \n",
"646 23.4 0.447 33 1 \n",
"658 39.0 0.190 51 0 \n",
"669 30.9 0.164 45 0 \n",
"674 35.6 0.587 68 0 \n",
"676 24.8 0.230 53 1 \n",
"682 44.6 0.366 22 0 \n",
"699 44.5 0.904 26 0 \n",
"702 35.0 0.905 52 1 \n",
"723 39.1 0.251 42 0 \n",
"725 39.4 0.236 38 0 \n",
"730 28.4 0.323 34 1 \n",
"744 40.6 1.174 39 0 \n",
"750 31.2 1.182 22 1 "
]
},
"execution_count": 102,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"Outcome\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>1.0</td>\n",
" <td>82.0</td>\n",
" <td>64.0</td>\n",
" <td>13.0</td>\n",
" <td>95.0</td>\n",
" <td>21.2</td>\n",
" <td>0.415</td>\n",
" <td>23.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"450 1.0 82.0 64.0 13.0 95.0 21.2 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"450 0.415 23.0 0.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>-0.838489</td>\n",
" <td>-0.482325</td>\n",
" <td>0.136961</td>\n",
" <td>-1.329999</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies SkinThickness Insulin BMI\n",
"450 -0.838489 -0.482325 0.136961 -1.329999"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 0 (proba: [0.81353825 0.18646175])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 0'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 450\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"import pandas as pd\n",
"\n",
"\n",
"# Определяем числовые признаки\n",
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
"\n",
"# Установка random_state\n",
"random_state = 42\n",
"\n",
"# Определение трансформера\n",
"pipeline_end = ColumnTransformer([\n",
" ('numeric', StandardScaler(), numeric_features),\n",
" # Добавьте другие трансформеры, если требуется\n",
"])\n",
"\n",
"# Объявление модели\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"# Создание пайплайна с корректными шагами\n",
"result = {}\n",
"\n",
"# Обучение модели\n",
"result[\"pipeline\"] = Pipeline([\n",
" (\"pipeline\", pipeline_end),\n",
" (\"model\", optimized_model)\n",
"]).fit(X_train, y_train.values.ravel())\n",
"\n",
"# Прогнозирование и расчет метрик\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"# Метрики для оценки модели\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_bc935_row0_col0, #T_bc935_row0_col1, #T_bc935_row0_col2, #T_bc935_row0_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_bc935_row0_col4, #T_bc935_row0_col5, #T_bc935_row0_col6, #T_bc935_row0_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_bc935_row1_col0, #T_bc935_row1_col1, #T_bc935_row1_col2, #T_bc935_row1_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_bc935_row1_col4, #T_bc935_row1_col5, #T_bc935_row1_col6, #T_bc935_row1_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_bc935\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_bc935_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_bc935_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_bc935_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_bc935_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_bc935_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_bc935_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_bc935_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_bc935_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_bc935_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_bc935_row0_col0\" class=\"data row0 col0\" >0.955157</td>\n",
" <td id=\"T_bc935_row0_col1\" class=\"data row0 col1\" >0.535714</td>\n",
" <td id=\"T_bc935_row0_col2\" class=\"data row0 col2\" >0.995327</td>\n",
" <td id=\"T_bc935_row0_col3\" class=\"data row0 col3\" >0.555556</td>\n",
" <td id=\"T_bc935_row0_col4\" class=\"data row0 col4\" >0.982085</td>\n",
" <td id=\"T_bc935_row0_col5\" class=\"data row0 col5\" >0.675325</td>\n",
" <td id=\"T_bc935_row0_col6\" class=\"data row0 col6\" >0.974828</td>\n",
" <td id=\"T_bc935_row0_col7\" class=\"data row0 col7\" >0.545455</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_bc935_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_bc935_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_bc935_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_bc935_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_bc935_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_bc935_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_bc935_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_bc935_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_bc935_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x175bbfc33e0>"
]
},
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_9cfcd_row0_col0, #T_9cfcd_row0_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9cfcd_row0_col2, #T_9cfcd_row0_col3, #T_9cfcd_row0_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9cfcd_row1_col0, #T_9cfcd_row1_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_9cfcd_row1_col2, #T_9cfcd_row1_col3, #T_9cfcd_row1_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_9cfcd\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_9cfcd_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_9cfcd_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_9cfcd_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_9cfcd_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_9cfcd_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_9cfcd_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_9cfcd_row0_col0\" class=\"data row0 col0\" >0.675325</td>\n",
" <td id=\"T_9cfcd_row0_col1\" class=\"data row0 col1\" >0.545455</td>\n",
" <td id=\"T_9cfcd_row0_col2\" class=\"data row0 col2\" >0.715093</td>\n",
" <td id=\"T_9cfcd_row0_col3\" class=\"data row0 col3\" >0.293059</td>\n",
" <td id=\"T_9cfcd_row0_col4\" class=\"data row0 col4\" >0.293176</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9cfcd_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_9cfcd_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_9cfcd_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_9cfcd_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_9cfcd_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_9cfcd_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x175bc5757c0>"
]
},
"execution_count": 107,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA3MAAAGxCAYAAADI9u/sAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABVkUlEQVR4nO3deVyVZf7/8fdBVtkUUxADlzSX1NxKscU0imzTZKbRbEYrbXIrtdL8lUtWYn7HdCyXMpOaNMuazKwso9HJNbW0TCO3AhfQNEAwFjnn9wfjqRNoIDcczrlez8fjfozc5z73+RxzfPu5r+u+bpvD4XAIAAAAAOBRfNxdAAAAAACg4mjmAAAAAMAD0cwBAAAAgAeimQMAAAAAD0QzBwAAAAAeiGYOAAAAADwQzRwAAAAAeCCaOQAAAADwQDRzAAAAAOCBaOYAAAAAwAPRzAEAAABABfz3v//VbbfdpujoaNlsNq1YscLldYfDoUmTJqlhw4YKCgpSfHy89u7d63LMyZMnNXDgQIWFhalOnTq67777lJubW6E6aOYAAAAAoALy8vJ0+eWXa+7cuWW+PmPGDM2ZM0cLFizQli1bFBwcrISEBOXn5zuPGThwoL799lutWbNGq1at0n//+1/df//9FarD5nA4HJX6JgAAAABgKJvNpnfffVd9+/aVVDIqFx0drYcffliPPPKIJCk7O1uRkZFKTk5W//79tWfPHrVp00Zbt25Vly5dJEmrV6/WzTffrEOHDik6Orpcn+1bJd8IAFDj5Ofnq7Cw0LLz+fv7KzAw0LLzAQBQEVbnmsPhkM1mc9kXEBCggICACp3n4MGDysjIUHx8vHNfeHi4unbtqk2bNql///7atGmT6tSp42zkJCk+Pl4+Pj7asmWL7rjjjnJ9Fs0cABggPz9fTRuHKONYsWXnjIqK0sGDB2noAADVripyLSQkpNQ9a5MnT9aUKVMqdJ6MjAxJUmRkpMv+yMhI52sZGRlq0KCBy+u+vr6KiIhwHlMeNHMAYIDCwkJlHCvWwe2NFRZa+dulc07Z1bTzjyosLKSZAwBUu6rKtfT0dIWFhTn3V3RUrrrRzAGAQcJCfSwJPQAAagKrcy0sLMylmbsQUVFRkqTMzEw1bNjQuT8zM1MdOnRwHnPs2DGX9505c0YnT550vr88SHQAMEixw27ZBgCAu9XEXGvatKmioqKUkpLi3JeTk6MtW7YoLi5OkhQXF6esrCxt377decxnn30mu92url27lvuzGJkDAIPY5ZBdlV/E2IpzAABQWe7KtdzcXO3bt8/588GDB7Vjxw5FREQoNjZWo0eP1tNPP60WLVqoadOmmjhxoqKjo50rXrZu3Vo33XSThg4dqgULFqioqEgjR45U//79y72SpUQzBwAAAAAVsm3bNvXs2dP589ixYyVJgwYNUnJyssaNG6e8vDzdf//9ysrK0tVXX63Vq1e73Ge+ZMkSjRw5Utdff718fHyUmJioOXPmVKgOnjMHAAbIyclReHi4jqRebNmN4tEtDyk7O7vS9xYAAFBR5FoJRuYAwCDFDoeKLbiGZ8U5AACoLNNzjQVQAAAAAMADMTIHAAZhARQAgDcxPddo5gDAIHY5VGxw6AEAvIvpucY0SwAAAADwQIzMAYBBTJ+OAgDwLqbnGiNzAAAAAOCBGJkDAIOYvoQzAMC7mJ5rNHMAYBD7/zYrzgMAgLuZnmtMswQAAAAAD8TIHAAYpNiiJZytOAcAAJVleq7RzAGAQYodJZsV5wEAwN1MzzWmWQIAAACAB2JkDgAMYvqN4gAA72J6rtHMAYBB7LKpWDZLzgMAgLuZnmtMswQAAAAAD8TIHAAYxO4o2aw4DwAA7mZ6rjEyBwAAAAAeiJE5ADBIsUX3FlhxDgAAKsv0XKOZAwCDmB56AADvYnquMc0SAAAAADwQzRwAGMTusFm2lVeTJk1ks9lKbSNGjJAk5efna8SIEapXr55CQkKUmJiozMzMqvotAAB4EXfkWk1CMwcABjk7HcWKrby2bt2qo0ePOrc1a9ZIkv785z9LksaMGaP3339fy5cv17p163TkyBH169evSr4/AMC7uCPXahLumQMAVKn69eu7/Dx9+nRdcskl6tGjh7Kzs7Vo0SItXbpUvXr1kiQtXrxYrVu31ubNm9WtWzd3lAwAgEegmQMAgxTLR8UWTMoo/t//5uTkuOwPCAhQQEDAOd9XWFio119/XWPHjpXNZtP27dtVVFSk+Ph45zGtWrVSbGysNm3aRDMHADgvq3PN0zDNEgAM4rDovgLH/+4tiImJUXh4uHNLSko67+evWLFCWVlZGjx4sCQpIyND/v7+qlOnjstxkZGRysjIqIrfAgCAF7E61zwNI3MAgAuWnp6usLAw58/nG5WTpEWLFql3796Kjo6u6tIAAPB6NHMAYBCrn8cTFhbm0sydz48//qhPP/1U//73v537oqKiVFhYqKysLJfRuczMTEVFRVW6TgCAd+M5cwAAYxQ7fCzbKmrx4sVq0KCBbrnlFue+zp07y8/PTykpKc59qampSktLU1xcnCXfGQDgvdyZazUBI3MAgCpnt9u1ePFiDRo0SL6+v0ZPeHi47rvvPo0dO1YREREKCwvTqFGjFBcXx+InAAD8AZo5ADCIXTbZLZiUYZejQsd/+umnSktL07333lvqtVmzZsnHx0eJiYkqKChQQkKC5s2bV+kaAQDez125VlPQzAEAqtyNN94oh6PsoAwMDNTcuXM1d+7caq4KAADPRjMHAAYx/UZxAIB3MT3XaOYAwCBW3eRdfI5RNgAAqpPpueaZy7YAAAAAgOEYmQMAg5TcKF75qSRWnAMAgMoyPddo5gDAIHb5qNjgVb8AAN7F9FxjmiUAAAAAeCBG5gDAIKbfKA4A8C6m5xrNHAAYxC4fox+uCgDwLqbnGtMsAQAAAMADMTIHAAYpdthU7LDg4aoWnAMAgMoyPdcYmQMAAAAAD8TIHAAYpNiiJZyLPfTeAgCAdzE912jmAMAgdoeP7Bas+mX30FW/AADexfRcY5olAAAAAHggRuYAwCCmT0cBAHgX03ONZg4ADGKXNSt22StfCgAAlWZ6rjHNEgAAAAA8ECNzAGAQu3xkt+A6nhXnAACgskzPNZo5ADBIscNHxRas+mXFOQAAqCzTc80zqwYAAAAAwzEyBwAGscsmu6y4Ubzy5wAAoLJMzzWaOQAwiOnTUQAA3sX0XPPMqgEAAADAcIzMAYBBrHu4KtcCAQDuZ3queWbVAAAAAGA4Ruaqmd1u15EjRxQaGiqbzTNvtARQvRwOh06dOqXo6Gj5+FTuGpzdYZPdYcGN4hacA96BXANQUeSadWjmqtmRI0cUExPj7jIAeKD09HRdfPHFlTqH3aLpKJ76cFVYj1wDcKHItcqjmatmoaGhkqQfv2yisBDP/EODqvOnLt3dXQJqoDOOQq079Zbz7w+gJiHXcD53XNrO3SWgBjqjIq3Xh+SaBWjmqtnZKShhIT4KCyX04MrX5u/uElCDWTGFze7wkd2C5ZetOAe8A7mG8/G1+bm7BNREjpL/Idcqj2YOAAxSLJuKLXgwqhXnAACgskzPNc9sQQEAAADAcIzMAYBBTJ+OAgDwLqbnGs0cABikWNZMJSmufCkAAFSa6bnmmS0oAAAAABiOkTkAMIjp01EAAN7F9FzzzKoBAAAAwHCMzAGAQYodPiq24OqjFecAAKCyTM81z6waAHBBHLLJbsHm8NDn8QAAvIs7cq24uFgTJ05U06ZNFRQUpEsuuURPPfWUHA7Hr3U5HJo0aZIaNmyooKAgxcfHa+/evZZ/f5o5AAAAACinZ599VvPnz9cLL7ygPXv26Nlnn9WMGTP0/PPPO4+ZMWOG5syZowULFmjLli0KDg5WQkKC8vPzLa2FaZYAYBDTp6MAALyLO3Jt48aN6tOnj2655RZJUpMmTfTGG2/oiy++kFQyKjd79mw98cQT6tOnjyTptddeU2RkpFasWKH+/ftXut6zSGMAMIjdYbNsAwDA3azOtZycHJetoKCg1Gd2795dKSkp+v777yVJO3fu1Pr169W7d29
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"\n",
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Healthy\", \"Sick\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В желтом квадрате мы видим значение 74, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"Sick\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
"\n",
"В зеленом квадрате значение 54 указывает на количество правильно классифицированных объектов, отнесенных к классу \"Healthy\". Это также является показателем хорошей точности модели в определении объектов данного класса."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение достижимого уровня качества модели для второй задачи"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подготовка данных"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin \\\n",
"count 768.000000 768.000000 768.000000 768.000000 768.000000 \n",
"mean 3.845052 120.894531 69.105469 20.536458 79.799479 \n",
"std 3.369578 31.972618 19.355807 15.952218 115.244002 \n",
"min 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
"25% 1.000000 99.000000 62.000000 0.000000 0.000000 \n",
"50% 3.000000 117.000000 72.000000 23.000000 30.500000 \n",
"75% 6.000000 140.250000 80.000000 32.000000 127.250000 \n",
"max 17.000000 199.000000 122.000000 99.000000 846.000000 \n",
"\n",
" BMI DiabetesPedigreeFunction Age Outcome \n",
"count 768.000000 768.000000 768.000000 768.000000 \n",
"mean 31.992578 0.471876 33.240885 0.348958 \n",
"std 7.884160 0.331329 11.760232 0.476951 \n",
"min 0.000000 0.078000 21.000000 0.000000 \n",
"25% 27.300000 0.243750 24.000000 0.000000 \n",
"50% 32.000000 0.372500 29.000000 0.000000 \n",
"75% 36.600000 0.626250 41.000000 1.000000 \n",
"max 67.100000 2.420000 81.000000 1.000000 \n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"\n",
"random_state = 9\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\".//scv//diabetes.csv\")\n",
"print(df.describe())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование выборок"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>1</td>\n",
" <td>105</td>\n",
" <td>58</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.3</td>\n",
" <td>0.187</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>4</td>\n",
" <td>146</td>\n",
" <td>85</td>\n",
" <td>27</td>\n",
" <td>100</td>\n",
" <td>28.9</td>\n",
" <td>0.189</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>3</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.174</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>5</td>\n",
" <td>88</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" <td>27.6</td>\n",
" <td>0.258</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>2</td>\n",
" <td>120</td>\n",
" <td>54</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>26.8</td>\n",
" <td>0.455</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>0</td>\n",
" <td>124</td>\n",
" <td>70</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.254</td>\n",
" <td>36</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>0</td>\n",
" <td>95</td>\n",
" <td>85</td>\n",
" <td>25</td>\n",
" <td>36</td>\n",
" <td>37.4</td>\n",
" <td>0.247</td>\n",
" <td>24</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>1</td>\n",
" <td>97</td>\n",
" <td>66</td>\n",
" <td>15</td>\n",
" <td>140</td>\n",
" <td>23.2</td>\n",
" <td>0.487</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>1</td>\n",
" <td>117</td>\n",
" <td>60</td>\n",
" <td>23</td>\n",
" <td>106</td>\n",
" <td>33.8</td>\n",
" <td>0.466</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>3</td>\n",
" <td>107</td>\n",
" <td>62</td>\n",
" <td>13</td>\n",
" <td>48</td>\n",
" <td>22.9</td>\n",
" <td>0.678</td>\n",
" <td>23</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"196 1 105 58 0 0 24.3 \n",
"69 4 146 85 27 100 28.9 \n",
"494 3 80 0 0 0 0.0 \n",
"463 5 88 78 30 0 27.6 \n",
"653 2 120 54 0 0 26.8 \n",
".. ... ... ... ... ... ... \n",
"322 0 124 70 20 0 27.4 \n",
"109 0 95 85 25 36 37.4 \n",
"27 1 97 66 15 140 23.2 \n",
"651 1 117 60 23 106 33.8 \n",
"197 3 107 62 13 48 22.9 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"196 0.187 21 0 \n",
"69 0.189 27 0 \n",
"494 0.174 22 0 \n",
"463 0.258 37 0 \n",
"653 0.455 27 0 \n",
".. ... ... ... \n",
"322 0.254 36 1 \n",
"109 0.247 24 1 \n",
"27 0.487 22 0 \n",
"651 0.466 27 0 \n",
"197 0.678 23 1 \n",
"\n",
"[614 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"196 0\n",
"69 0\n",
"494 0\n",
"463 0\n",
"653 0\n",
".. ...\n",
"322 1\n",
"109 1\n",
"27 0\n",
"651 0\n",
"197 1\n",
"\n",
"[614 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>9</td>\n",
" <td>154</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>30.9</td>\n",
" <td>0.164</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>379</th>\n",
" <td>0</td>\n",
" <td>93</td>\n",
" <td>100</td>\n",
" <td>39</td>\n",
" <td>72</td>\n",
" <td>43.4</td>\n",
" <td>1.021</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>640</th>\n",
" <td>0</td>\n",
" <td>102</td>\n",
" <td>86</td>\n",
" <td>17</td>\n",
" <td>105</td>\n",
" <td>29.3</td>\n",
" <td>0.695</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>11</td>\n",
" <td>127</td>\n",
" <td>106</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>39.0</td>\n",
" <td>0.190</td>\n",
" <td>51</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>304</th>\n",
" <td>3</td>\n",
" <td>150</td>\n",
" <td>76</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.0</td>\n",
" <td>0.207</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203</th>\n",
" <td>2</td>\n",
" <td>99</td>\n",
" <td>70</td>\n",
" <td>16</td>\n",
" <td>44</td>\n",
" <td>20.4</td>\n",
" <td>0.235</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>605</th>\n",
" <td>1</td>\n",
" <td>124</td>\n",
" <td>60</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.8</td>\n",
" <td>0.514</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>561</th>\n",
" <td>0</td>\n",
" <td>198</td>\n",
" <td>66</td>\n",
" <td>32</td>\n",
" <td>274</td>\n",
" <td>41.3</td>\n",
" <td>0.502</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>280</th>\n",
" <td>0</td>\n",
" <td>146</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>37.9</td>\n",
" <td>0.334</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103</th>\n",
" <td>1</td>\n",
" <td>81</td>\n",
" <td>72</td>\n",
" <td>18</td>\n",
" <td>40</td>\n",
" <td>26.6</td>\n",
" <td>0.283</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"669 9 154 78 30 100 30.9 \n",
"379 0 93 100 39 72 43.4 \n",
"640 0 102 86 17 105 29.3 \n",
"658 11 127 106 0 0 39.0 \n",
"304 3 150 76 0 0 21.0 \n",
".. ... ... ... ... ... ... \n",
"203 2 99 70 16 44 20.4 \n",
"605 1 124 60 32 0 35.8 \n",
"561 0 198 66 32 274 41.3 \n",
"280 0 146 70 0 0 37.9 \n",
"103 1 81 72 18 40 26.6 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"669 0.164 45 0 \n",
"379 1.021 35 0 \n",
"640 0.695 27 0 \n",
"658 0.190 51 0 \n",
"304 0.207 37 0 \n",
".. ... ... ... \n",
"203 0.235 27 0 \n",
"605 0.514 21 0 \n",
"561 0.502 28 1 \n",
"280 0.334 28 1 \n",
"103 0.283 24 0 \n",
"\n",
"[154 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>379</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>640</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>304</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>605</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>561</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>280</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"669 0\n",
"379 0\n",
"640 0\n",
"658 0\n",
"304 0\n",
".. ...\n",
"203 0\n",
"605 0\n",
"561 1\n",
"280 1\n",
"103 0\n",
"\n",
"[154 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input: DataFrame,\n",
" stratify_colname: str = \"y\",\n",
" frac_train: float = 0.6,\n",
" frac_val: float = 0.15,\n",
" frac_test: float = 0.25,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
"\n",
" if not (0 < frac_train < 1) or not (0 <= frac_val <= 1) or not (0 <= frac_test <= 1):\n",
" raise ValueError(\"Fractions must be between 0 and 1 and the sum must equal 1.\")\n",
" \n",
" if not (frac_train + frac_val + frac_test == 1.0):\n",
" raise ValueError(\"fractions %f, %f, %f do not add up to 1.0\" %\n",
" (frac_train, frac_val, frac_test))\n",
"\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(f\"{stratify_colname} is not a column in the DataFrame.\")\n",
"\n",
" X = df_input\n",
" y = df_input[[stratify_colname]]\n",
"\n",
" \n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
"\n",
" if frac_val == 0:\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
"\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
"\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
"\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" \n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"Outcome\", frac_train=0.80, frac_val=0.0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование конвейера для классификации данных"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"class DiabetFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
" def fit(self, X, y=None):\n",
" return self\n",
" \n",
"\n",
"columns_to_drop = [\"Pregnancies\", \"SkinThickness\", \"Insulin\", \"BMI\"]\n",
"num_columns = [\"Glucose\", \"Age\", \"BloodPressure\", \"Outcome\", \"DiabetesPedigreeFunction\"]\n",
"cat_columns = []\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Демонстрация работы конвейера"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Glucose</th>\n",
" <th>Age</th>\n",
" <th>BloodPressure</th>\n",
" <th>Outcome</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>-0.478144</td>\n",
" <td>-1.029257</td>\n",
" <td>-0.554050</td>\n",
" <td>-0.731437</td>\n",
" <td>-0.849205</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>0.818506</td>\n",
" <td>-0.522334</td>\n",
" <td>0.804885</td>\n",
" <td>-0.731437</td>\n",
" <td>-0.843172</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>-1.268784</td>\n",
" <td>-0.944770</td>\n",
" <td>-3.473244</td>\n",
" <td>-0.731437</td>\n",
" <td>-0.888421</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>-1.015779</td>\n",
" <td>0.322537</td>\n",
" <td>0.452568</td>\n",
" <td>-0.731437</td>\n",
" <td>-0.635028</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>-0.003760</td>\n",
" <td>-0.522334</td>\n",
" <td>-0.755374</td>\n",
" <td>-0.731437</td>\n",
" <td>-0.040763</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>0.122742</td>\n",
" <td>0.238050</td>\n",
" <td>0.049921</td>\n",
" <td>1.367172</td>\n",
" <td>-0.647095</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>-0.794400</td>\n",
" <td>-0.775796</td>\n",
" <td>0.804885</td>\n",
" <td>1.367172</td>\n",
" <td>-0.668211</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>-0.731149</td>\n",
" <td>-0.944770</td>\n",
" <td>-0.151403</td>\n",
" <td>-0.731437</td>\n",
" <td>0.055767</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>-0.098637</td>\n",
" <td>-0.522334</td>\n",
" <td>-0.453388</td>\n",
" <td>-0.731437</td>\n",
" <td>-0.007581</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>-0.414893</td>\n",
" <td>-0.860283</td>\n",
" <td>-0.352726</td>\n",
" <td>1.367172</td>\n",
" <td>0.631933</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 5 columns</p>\n",
"</div>"
],
"text/plain": [
" Glucose Age BloodPressure Outcome DiabetesPedigreeFunction\n",
"196 -0.478144 -1.029257 -0.554050 -0.731437 -0.849205\n",
"69 0.818506 -0.522334 0.804885 -0.731437 -0.843172\n",
"494 -1.268784 -0.944770 -3.473244 -0.731437 -0.888421\n",
"463 -1.015779 0.322537 0.452568 -0.731437 -0.635028\n",
"653 -0.003760 -0.522334 -0.755374 -0.731437 -0.040763\n",
".. ... ... ... ... ...\n",
"322 0.122742 0.238050 0.049921 1.367172 -0.647095\n",
"109 -0.794400 -0.775796 0.804885 1.367172 -0.668211\n",
"27 -0.731149 -0.944770 -0.151403 -0.731437 0.055767\n",
"651 -0.098637 -0.522334 -0.453388 -0.731437 -0.007581\n",
"197 -0.414893 -0.860283 -0.352726 1.367172 0.631933\n",
"\n",
"[614 rows x 5 columns]"
]
},
"execution_count": 122,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование набора моделей для классификации"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"random_state = 9\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение моделей на обучающем наборе данных и оценка на тестовом¶"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сводная таблица оценок качества для использованных моделей классификации¶\n",
"\n",
"Матрица неточностей\n"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAz8AAAQ9CAYAAAB3OvPGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXxM5/4H8M/JNonsISsRkdhSW1CqoZRUqFourUtpE5RW7YpyW/sS9NdWqaVVgl6utoqri6iqUKquvUXEFsSSWCKJhGwzz++P1LTTJJNMnMnJnPm87+u8rjzPyZnnRHM+vuc5iySEECAiIiIiIlI5G6UHQEREREREVBlY/BARERERkVVg8UNERERERFaBxQ8REREREVkFFj9ERERERGQVWPwQEREREZFVYPFDRERERERWgcUPERERERFZBRY/RERERERkFVj8UIWsXbsWkiTh8uXLZtn+5cuXIUkS1q5dK8v2EhISIEkSEhISZNkeERGRWsycOROSJJVrXUmSMHPmTPMOiMiMWPyQqixfvly2gomIiIiI1MVO6QEQlSQoKAgPHz6Evb29Sd+3fPly1KhRAzExMQbtzzzzDB4+fAgHBwcZR0lERGT53n33XUyZMkXpYRBVChY/VCVJkgRHR0fZtmdjYyPr9oiIiNQgJycHzs7OsLPjPwnJOvCyN5LN8uXL8cQTT0Cj0SAgIAAjR45ERkZGsfWWLVuGunXrwsnJCa1bt8bPP/+Mjh07omPHjvp1SrrnJzU1FYMHD0atWrWg0Wjg7++PXr166e87qlOnDk6fPo29e/dCkiRIkqTfZmn3/Bw6dAjPP/88PD094ezsjKZNm+Kjjz6S9wdDRERUBTy6t+fMmTN4+eWX4enpiXbt2pV4z09eXh7Gjx8Pb29vuLq6omfPnrh27VqJ201ISECrVq3g6OiIkJAQfPLJJ6XeR/Tvf/8bLVu2hJOTE7y8vNC/f3+kpKSYZX+JSsIyn2Qxc+ZMzJo1C5GRkRgxYgSSkpKwYsUKHD58GAcOHNBfvrZixQqMGjUK7du3x/jx43H58mX07t0bnp6eqFWrltHP6Nu3L06fPo3Ro0ejTp06uHXrFnbt2oWrV6+iTp06WLx4MUaPHg0XFxe88847AABfX99St7dr1y688MIL8Pf3x9ixY+Hn54fExER8++23GDt2rHw/HCIioirkpZdeQr169TB//nwIIXDr1q1i67z22mv497//jZdffhlPP/00fvrpJ3Tv3r3YesePH0fXrl3h7++PWbNmQavVYvbs2fD29i627rx58zBt2jT069cPr732Gm7fvo2lS5fimWeewfHjx+Hh4WGO3SUyJIgqIC4uTgAQycnJ4tatW8LBwUF06dJFaLVa/Toff/yxACDWrFkjhBAiLy9PVK9eXTz55JOioKBAv97atWsFANGhQwd9W3JysgAg4uLihBBC3Lt3TwAQ7733ntFxPfHEEwbbeWTPnj0CgNizZ48QQojCwkIRHBwsgoKCxL179wzW1el05f9BEBERWYgZM2YIAGLAgAEltj9y4sQJAUC8+eabBuu9/PLLAoCYMWOGvq1Hjx6iWrVq4vr16/q28+fPCzs7O4NtXr58Wdja2op58+YZbPP3338XdnZ2xdqJzIWXvdFj+/HHH5Gfn49x48bBxubP/6SGDRsGNzc3fPfddwCAI0eO4O7duxg2bJjBtcUDBw6Ep6en0c9wcnKCg4MDEhIScO/evcce8/Hjx5GcnIxx48YVO9NU3sd9EhERWaI33njDaP/3338PABgzZoxB+7hx4wy+1mq1+PHHH9G7d28EBATo20NDQ9GtWzeDdbds2QKdTod+/frhzp07+sXPzw/16tXDnj17HmOPiMqPl73RY7ty5QoAoEGDBgbtDg4OqFu3rr7/0f+HhoYarGdnZ4c6deoY/QyNRoOFCxfirbfegq+vL5566im88MILePXVV+Hn52fymC9evAgAaNy4scnfS0REZMmCg4ON9l+5cgU2NjYICQkxaP97zt+6dQsPHz4slutA8aw/f/48hBCoV69eiZ9p6tNdiSqKxQ9ZjHHjxqFHjx7Ytm0bdu7ciWnTpiE2NhY//fQTwsPDlR4eERGRRXBycqr0z9TpdJAkCTt27ICtrW2xfhcXl0ofE1knXvZGjy0oKAgAkJSUZNCen5+P5ORkff+j/79w4YLBeoWFhfontpUlJCQEb731Fn744QecOnUK+fn5eP/99/X95b1k7dHZrFOnTpVrfSIiImsRFBQEnU6nv0rikb/nvI+PDxwdHYvlOlA860NCQiCEQHBwMCIjI4stTz31lPw7QlQCFj/02CIjI+Hg4IAlS5ZACKFvX716NTIzM/VPh2nVqhWqV6+OVatWobCwUL/ehg0byryP58GDB8jNzTVoCwkJgaurK/Ly8vRtzs7OJT5e++9atGiB4OBgLF68uNj6f90HIiIia/Pofp0lS5YYtC9evNjga1tbW0RGRmLbtm24ceOGvv3ChQvYsWOHwbp9+vSBra0tZs2aVSxnhRC4e/eujHtAVDpe9kaPzdvbG1OnTsWsWbPQtWtX9OzZE0lJSVi+fDmefPJJDBo0CEDRPUAzZ87E6NGj0alTJ/Tr1w+XL1/G2rVrERISYnTW5ty5c+jcuTP69euHsLAw2NnZYevWrUhLS0P//v3167Vs2RIrVqzA3LlzERoaCh8fH3Tq1KnY9mxsbLBixQr06NEDzZs3x+DBg+Hv74+zZ8/i9OnT2Llzp/w/KCIiIgvQvHlzDBgwAMuXL0dmZiaefvpp7N69u8QZnpkzZ+KHH35AREQERowYAa1Wi48//hiNGzfGiRMn9OuFhIRg7ty5mDp1qv41F66urkhOTsbWrVsxfPhwTJw4sRL3kqwVix+SxcyZM+Ht7Y2PP/4Y48ePh5eXF4YPH4758+cb3MQ4atQoCCHw/vvvY+LEiWjWrBm2b9+OMWPGwNHRsdTtBwYGYsCAAdi9ezc+//xz2NnZoWHDhvjyyy/Rt29f/XrTp0/HlStXsGjRIty/fx8dOnQosfgBgKioKOzZswezZs3C+++/D51Oh5CQEAwbNky+HwwREZEFWrNmDby9vbFhwwZs27YNnTp1wnfffYfAwECD9Vq2bIkdO3Zg4sSJmDZtGgIDAzF79mwkJibi7NmzButOmTIF9evXx4cffohZs2YBKMr3Ll26oGfPnpW2b2TdJMFrfEhhOp0O3t7e6NOnD1atWqX0cIiIiOgx9e7dG6dPn8b58+eVHgqRAd7zQ5UqNze32LW+69evR3p6Ojp27KjMoIiIiKjCHj58aPD1+fPn8f333zPXqUrizA9VqoSEBIwfPx4vvfQSqlevjmPHjmH16tVo1KgRjh49CgcHB6WHSERERCbw9/dHTEyM/t1+K1asQF5eHo4fP17qe32IlMJ7fqhS1alTB4GBgViyZAnS09Ph5eWFV199FQsWLGDhQ0REZIG6du2K//znP0hNTYVGo0Hbtm0xf/58Fj5UJfGyN6pUderUwfbt25Gamor8/HykpqZizZo18PHxUXpopBL79u1Djx49EBAQAEmSsG3bNoN+IQSmT58Of39/ODk5ITIystg16enp6Rg4cCDc3Nzg4eGBoUOHIjs7uxL3gojIcsTFxeHy5cvIzc1FZmYm4uPj0aJFC6WHRVVIVcpmFj9EpCo5OTlo1qwZli1bVmL/okWLsGTJEqxcuRKHDh2Cs7MzoqKiDN4jNXDgQJw+fRq7du3Ct99+i3379mH48OGVtQtERESqUpWymff8EJFqSZKErVu3onfv3gCKziwFBATgrbfe0r9PIjMzE76+vli7di369++PxMREhIWF4fDhw2jVqhUAID4+Hs8//zyuXbuGgIAApXaHiIjI4imdzbznpxx0Oh1u3LgBV1dXoy/iJFIjIQTu37+PgIAA2NjIO1mcm5uL/Pz8co3h7797Go0GGo3GpM9LTk5GamoqIiMj9W3u7u5o06YNDh48iP79++PgwYPw8PDQH1wBIDIyEjY2Njh06BD+8Y9/mPSZRGQezGayZszmimczi59yuHHjRrGXehFZm5S
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_ccbd1_row0_col0, #T_ccbd1_row0_col1, #T_ccbd1_row0_col2, #T_ccbd1_row0_col3, #T_ccbd1_row1_col0, #T_ccbd1_row1_col1, #T_ccbd1_row1_col2, #T_ccbd1_row1_col3, #T_ccbd1_row2_col0, #T_ccbd1_row2_col1, #T_ccbd1_row2_col2, #T_ccbd1_row2_col3, #T_ccbd1_row3_col0, #T_ccbd1_row3_col1, #T_ccbd1_row3_col2, #T_ccbd1_row3_col3, #T_ccbd1_row4_col0, #T_ccbd1_row4_col1, #T_ccbd1_row4_col2, #T_ccbd1_row4_col3, #T_ccbd1_row5_col0, #T_ccbd1_row5_col1, #T_ccbd1_row5_col2, #T_ccbd1_row5_col3, #T_ccbd1_row6_col0, #T_ccbd1_row6_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_ccbd1_row0_col4, #T_ccbd1_row0_col5, #T_ccbd1_row0_col6, #T_ccbd1_row0_col7, #T_ccbd1_row1_col4, #T_ccbd1_row1_col5, #T_ccbd1_row1_col6, #T_ccbd1_row1_col7, #T_ccbd1_row2_col4, #T_ccbd1_row2_col5, #T_ccbd1_row2_col6, #T_ccbd1_row2_col7, #T_ccbd1_row3_col4, #T_ccbd1_row3_col5, #T_ccbd1_row3_col6, #T_ccbd1_row3_col7, #T_ccbd1_row4_col4, #T_ccbd1_row4_col5, #T_ccbd1_row4_col6, #T_ccbd1_row4_col7, #T_ccbd1_row5_col4, #T_ccbd1_row5_col5, #T_ccbd1_row5_col6, #T_ccbd1_row5_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_ccbd1_row6_col1 {\n",
" background-color: #98d83e;\n",
" color: #000000;\n",
"}\n",
"#T_ccbd1_row6_col2 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_ccbd1_row6_col4, #T_ccbd1_row6_col6 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_ccbd1_row6_col5 {\n",
" background-color: #d7566c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_ccbd1_row6_col7 {\n",
" background-color: #d8576b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_ccbd1_row7_col0, #T_ccbd1_row7_col1, #T_ccbd1_row7_col2, #T_ccbd1_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_ccbd1_row7_col4, #T_ccbd1_row7_col5, #T_ccbd1_row7_col6, #T_ccbd1_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_ccbd1\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_ccbd1_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_ccbd1_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_ccbd1_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_ccbd1_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_ccbd1_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_ccbd1_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_ccbd1_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_ccbd1_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_ccbd1_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_ccbd1_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_ccbd1_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_ccbd1_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_ccbd1_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_ccbd1_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_ccbd1_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_ccbd1_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_ccbd1_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_ccbd1_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_ccbd1_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_ccbd1_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row5_col5\" class=\"data row5 col5\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row5_col6\" class=\"data row5 col6\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row5_col7\" class=\"data row5 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_ccbd1_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_ccbd1_row6_col0\" class=\"data row6 col0\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row6_col1\" class=\"data row6 col1\" >0.981818</td>\n",
" <td id=\"T_ccbd1_row6_col2\" class=\"data row6 col2\" >0.990654</td>\n",
" <td id=\"T_ccbd1_row6_col3\" class=\"data row6 col3\" >1.000000</td>\n",
" <td id=\"T_ccbd1_row6_col4\" class=\"data row6 col4\" >0.996743</td>\n",
" <td id=\"T_ccbd1_row6_col5\" class=\"data row6 col5\" >0.993506</td>\n",
" <td id=\"T_ccbd1_row6_col6\" class=\"data row6 col6\" >0.995305</td>\n",
" <td id=\"T_ccbd1_row6_col7\" class=\"data row6 col7\" >0.990826</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_ccbd1_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_ccbd1_row7_col0\" class=\"data row7 col0\" >0.729323</td>\n",
" <td id=\"T_ccbd1_row7_col1\" class=\"data row7 col1\" >0.685714</td>\n",
" <td id=\"T_ccbd1_row7_col2\" class=\"data row7 col2\" >0.453271</td>\n",
" <td id=\"T_ccbd1_row7_col3\" class=\"data row7 col3\" >0.444444</td>\n",
" <td id=\"T_ccbd1_row7_col4\" class=\"data row7 col4\" >0.750814</td>\n",
" <td id=\"T_ccbd1_row7_col5\" class=\"data row7 col5\" >0.733766</td>\n",
" <td id=\"T_ccbd1_row7_col6\" class=\"data row7 col6\" >0.559078</td>\n",
" <td id=\"T_ccbd1_row7_col7\" class=\"data row7 col7\" >0.539326</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x175b9062330>"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"Почти все модели, включая логистическую регрессию, ридж-регрессию, KNN, наивный байесовский классификатор, многослойную перцептронную сеть, случайный лес, дерево решений и градиентный бустинг, демонстрируют 100% точность (1.000000) на обучающей выборке. Это указывает на то, что модели смогли подстроиться под обучающие данные, что может указывать на возможное переобучение.\n",
"\n",
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса\n"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_bf885_row0_col0, #T_bf885_row0_col1, #T_bf885_row1_col0, #T_bf885_row1_col1, #T_bf885_row2_col0, #T_bf885_row2_col1, #T_bf885_row4_col0, #T_bf885_row4_col1, #T_bf885_row5_col0, #T_bf885_row5_col1, #T_bf885_row6_col0, #T_bf885_row6_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_bf885_row0_col2, #T_bf885_row0_col3, #T_bf885_row0_col4, #T_bf885_row1_col2, #T_bf885_row1_col3, #T_bf885_row1_col4, #T_bf885_row2_col2, #T_bf885_row2_col3, #T_bf885_row2_col4, #T_bf885_row3_col2, #T_bf885_row4_col2, #T_bf885_row4_col3, #T_bf885_row4_col4, #T_bf885_row5_col2, #T_bf885_row5_col3, #T_bf885_row5_col4, #T_bf885_row6_col2, #T_bf885_row6_col3, #T_bf885_row6_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_bf885_row3_col0 {\n",
" background-color: #a0da39;\n",
" color: #000000;\n",
"}\n",
"#T_bf885_row3_col1 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_bf885_row3_col3, #T_bf885_row3_col4 {\n",
" background-color: #d8576b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_bf885_row7_col0, #T_bf885_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_bf885_row7_col2, #T_bf885_row7_col3, #T_bf885_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_bf885\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_bf885_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_bf885_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_bf885_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_bf885_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_bf885_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_bf885_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_bf885_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_bf885_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_bf885_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_bf885_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_bf885_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_bf885_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_bf885_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_bf885_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_bf885_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_bf885_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_bf885_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_bf885_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_bf885_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_bf885_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_bf885_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_bf885_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_bf885_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_bf885_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_bf885_row3_col0\" class=\"data row3 col0\" >0.993506</td>\n",
" <td id=\"T_bf885_row3_col1\" class=\"data row3 col1\" >0.990826</td>\n",
" <td id=\"T_bf885_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_bf885_row3_col3\" class=\"data row3 col3\" >0.985801</td>\n",
" <td id=\"T_bf885_row3_col4\" class=\"data row3 col4\" >0.985901</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_bf885_level0_row4\" class=\"row_heading level0 row4\" >naive_bayes</th>\n",
" <td id=\"T_bf885_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_bf885_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_bf885_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_bf885_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_bf885_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_bf885_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_bf885_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_bf885_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_bf885_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_bf885_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_bf885_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_bf885_level0_row6\" class=\"row_heading level0 row6\" >random_forest</th>\n",
" <td id=\"T_bf885_row6_col0\" class=\"data row6 col0\" >1.000000</td>\n",
" <td id=\"T_bf885_row6_col1\" class=\"data row6 col1\" >1.000000</td>\n",
" <td id=\"T_bf885_row6_col2\" class=\"data row6 col2\" >1.000000</td>\n",
" <td id=\"T_bf885_row6_col3\" class=\"data row6 col3\" >1.000000</td>\n",
" <td id=\"T_bf885_row6_col4\" class=\"data row6 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_bf885_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_bf885_row7_col0\" class=\"data row7 col0\" >0.733766</td>\n",
" <td id=\"T_bf885_row7_col1\" class=\"data row7 col1\" >0.539326</td>\n",
" <td id=\"T_bf885_row7_col2\" class=\"data row7 col2\" >0.653148</td>\n",
" <td id=\"T_bf885_row7_col3\" class=\"data row7 col3\" >0.363893</td>\n",
" <td id=\"T_bf885_row7_col4\" class=\"data row7 col4\" >0.380814</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x175bec42f60>"
]
},
"execution_count": 129,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Predicted</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [Pregnancies, Predicted, Glucose, BloodPressure, SkinThickness, Insulin, BMI, DiabetesPedigreeFunction, Age, Outcome]\n",
"Index: []"
]
},
"execution_count": 132,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"Outcome\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>555</th>\n",
" <td>7.0</td>\n",
" <td>124.0</td>\n",
" <td>70.0</td>\n",
" <td>33.0</td>\n",
" <td>215.0</td>\n",
" <td>25.5</td>\n",
" <td>0.161</td>\n",
" <td>37.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"555 7.0 124.0 70.0 33.0 215.0 25.5 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"555 0.161 37.0 0.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Glucose</th>\n",
" <th>Age</th>\n",
" <th>BloodPressure</th>\n",
" <th>Outcome</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>555</th>\n",
" <td>0.122742</td>\n",
" <td>0.322537</td>\n",
" <td>0.049921</td>\n",
" <td>-0.731437</td>\n",
" <td>-0.927636</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Glucose Age BloodPressure Outcome DiabetesPedigreeFunction\n",
"555 0.122742 0.322537 0.049921 -0.731437 -0.927636"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 0 (proba: [0.99431769 0.00568231])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 0'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 555\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 142,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\5_semester\\AIM\\rep\\AIM-PIbd-31-Razubaev-S-M\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 5,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 10}"
]
},
"execution_count": 142,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 143,
"metadata": {},
"outputs": [],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"log2\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": 144,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 145,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_51672_row0_col0, #T_51672_row0_col1, #T_51672_row0_col2, #T_51672_row0_col3, #T_51672_row1_col0, #T_51672_row1_col1, #T_51672_row1_col2, #T_51672_row1_col3 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_51672_row0_col4, #T_51672_row0_col5, #T_51672_row0_col6, #T_51672_row0_col7, #T_51672_row1_col4, #T_51672_row1_col5, #T_51672_row1_col6, #T_51672_row1_col7 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_51672\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_51672_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_51672_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_51672_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_51672_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_51672_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_51672_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_51672_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_51672_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_51672_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_51672_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_51672_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_51672_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_51672_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_51672_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_51672_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_51672_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_51672_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_51672_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_51672_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_51672_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_51672_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_51672_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_51672_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_51672_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_51672_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_51672_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x175be797c50>"
]
},
"execution_count": 145,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 146,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_abbc3_row0_col0, #T_abbc3_row0_col1, #T_abbc3_row1_col0, #T_abbc3_row1_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_abbc3_row0_col2, #T_abbc3_row0_col3, #T_abbc3_row0_col4, #T_abbc3_row1_col2, #T_abbc3_row1_col3, #T_abbc3_row1_col4 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_abbc3\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_abbc3_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_abbc3_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_abbc3_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_abbc3_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_abbc3_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_abbc3_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_abbc3_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_abbc3_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_abbc3_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_abbc3_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_abbc3_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_abbc3_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_abbc3_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_abbc3_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_abbc3_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_abbc3_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_abbc3_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x175b9028fe0>"
]
},
"execution_count": 146,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 147,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2AAAAGxCAYAAAAEb9UHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABCWUlEQVR4nO3deXgV5d3/8c8kIQtZCUtCNIQAyqIIASoiKFUjgbpRbPlpQcOiVAURKGsVFFSC9EERRUCQ7XmkaF2oYsUiFpRFKii0atgpoJCAjUlIMOuZ3x/I0SMJEs6cOZOT9+u65iqZmdy5D+bi0+/c35kxTNM0BQAAAADwuSB/TwAAAAAA6goKMAAAAACwCQUYAAAAANiEAgwAAAAAbEIBBgAAAAA2oQADAAAAAJtQgAEAAACATSjAAAAAAMAmFGAAAAAAYBMKMAAAAACwCQUYAOAsH374oW655RYlJSXJMAytWrXK47hpmpoyZYqaNm2qiIgIpaena+/evR7n5OXlacCAAYqJiVFcXJyGDh2qoqIiGz8FAAA/cEq2UYABAM5SXFysDh06aO7cuVUenzlzpubMmaP58+dr69atioyMVEZGhkpKStznDBgwQF988YXWrl2r1atX68MPP9SwYcPs+ggAAHhwSrYZpmmaXn0SAEBAMwxDb775pvr27Svp9BXCpKQk/eEPf9DYsWMlSQUFBUpISNDSpUt1xx13KDs7W+3atdMnn3yiLl26SJLWrFmjX/3qV/rqq6+UlJTkr48DAIBfsy3EJ58IAOC1kpISlZWVWTaeaZoyDMNjX1hYmMLCwmo0zsGDB5WTk6P09HT3vtjYWHXt2lVbtmzRHXfcoS1btiguLs4dUJKUnp6uoKAgbd26Vb/+9a+9+zAAgFrHqbkm2ZttFGAA4EAlJSVKTYlSzvFKy8aMioo6q0/90Ucf1WOPPVajcXJyciRJCQkJHvsTEhLcx3JyctSkSROP4yEhIYqPj3efAwCoO5yca5K92UYBBgAOVFZWppzjlTq4PUUx0d7frlt40qXUzod05MgRxcTEuPdfyFVCAABqilz7AQUYADhYTHSQJUHlHi8mxiOoLkRiYqIkKTc3V02bNnXvz83NVceOHd3nHD9+3OP7KioqlJeX5/5+AEDd48Rck+zNNp6CCAAOVmm6LNuskpqaqsTERK1bt869r7CwUFu3blW3bt0kSd26dVN+fr62b9/uPueDDz6Qy+VS165dLZsLAKB2cWKuSfZmGytgAOBgLplyyfuH1dZ0jKKiIu3bt8/99cGDB7Vjxw7Fx8erWbNmGjVqlJ544gldcsklSk1N1eTJk5WUlOR+mlTbtm3Vu3dv3XvvvZo/f77Ky8s1YsQI3XHHHTwBEQDqMH/lmuScbKMAAwCcZdu2bbruuuvcX48ZM0aSlJmZqaVLl2r8+PEqLi7WsGHDlJ+frx49emjNmjUKDw93f8/LL7+sESNG6IYbblBQUJBuv/12zZkzx/bPAgCA5Jxs4z1gAOBAhYWFio2N1dHdF1t2s3JS669UUFBgSa88AAA1Qa79gBUwAHCwStNUpQXXyawYAwAAb5FrPIQDAAAAAGzDChgAOJg/b1YGAMBq5BoFGAA4mkumKut4UAEAAge5RgsiAAAAANiGFTAAcDBaNQAAgYRcYwUMAAAAAGzDChgAOBiP6wUABBJyjQIMABzN9f1mxTgAAPgbuUYLIgAAAADYhhUwAHCwSose12vFGAAAeItcowADAEerNE9vVowDAIC/kWu0IAIAAACAbVgBAwAH42ZlAEAgIdcowADA0VwyVCnDknEAAPA3co0WRAAAAACwDStgAOBgLvP0ZsU4AAD4G7nGChgAAAAA2IYVMABwsEqLeuWtGAMAAG+RaxRgAOBoBBUAIJCQa7QgAgAAAIBtWAEDAAdzmYZcpgWP67VgDAAAvEWuUYABgKPRqgEACCTkGi2IAAAAAGAbVsAAwMEqFaRKC66VVVowFwAAvEWuUYABgKOZFvXKm7W4Vx4AEDjINVoQAQAAAMA2rIABgINxszIAIJCQaxRgAOBolWaQKk0LeuVNCyYDAICXyDVaEAEAAADANqyAAYCDuWTIZcG1Mpdq8aVCAEDAINdYAQMAAAAA27ACBgAOxs3KAIBAQq5RgAGAo1l3s3LtbdUAAAQOco0WRAAAAACwDStgAOBgp29W9r7NwooxAADwFrlGAQYAjuZSkCrr+NOiAACBg1yjBREAAAAAbMMKGAA4GDcrAwACCblGAQYAjuZSUJ1/YSUAIHCQa7QgAgAAAIBtWAEDAAerNA1Vmha8sNKCMQAA8Ba5xgoYAAAAANiGFTAAcLBKix7XW1mLe+UBAIGDXKMAAwBHc5lBclnwtChXLX5aFAAgcJBrtCACAAAAgG1YAQMAB6NVAwAQSMg1CjAAcDSXrHnSk8v7qQAA4DVyjRZEAAAAALANK2AA4GAuBcllwbUyK8YAAMBb5BoFGAA4WqUZpEoLnhZlxRgAAHiLXKMFEQAAAABswwoYADiYS4ZcsuJmZe/HAADAW+QaBRgAOBqtGgCAQEKu0YIIAAAAALZhBQwAHMy6F1ZyvQ0A4H/kGitgAAAAAGAbVsDOg8vl0tGjRxUdHS3DqL03/AGwh2maOnnypJKSkhQU5N11LpdpyGVacLOyBWMgcJBrAGqCXLMWBdh5OHr0qJKTk/09DQC1zJEjR3TxxRd7NYbLolaN2vzCSliPXANwIcg1a1CAnYfo6GhJ0qFPmysmqvb+x4Zv/PrS9v6eAhymQuXaqL+5/+0AnIZcw7mQa/gpcs1aFGDn4Ux7RkxUkGKiCSp4CjHq+XsKcBrz9P9Y0drlMoPksuBRu1aMgcBBruFcyDWchVyzFAUYADhYpQxVWvCySSvGAADAW+QaT0EEAAAAANuwAgYADkarBgAgkJBrFGAA4GiVsqbNotL7qQAA4DVyjRZEAAAAALANK2AA4GC0agAAAgm5xgoYAAAAANiGFTAAcLBKM0iVFlzls2IMAAC8Ra6xAgYAjmbKkMuCzazhDc+VlZWaPHmyUlNTFRERoZYtW+rxxx+XaZo/zM00NWXKFDVt2lQRERFKT0/X3r17rf4rAAAEEHKNAgwAUIWnnnpK8+bN0/PPP6/s7Gw99dRTmjlzpp577jn3OTNnztScOXM0f/58bd26VZGRkcrIyFBJSYkfZw4AwNmclGu0IAKAg/mrVWPz5s267bbbdNNNN0mSmjdvrj//+c/65z//Ken0VcLZs2frkUce0W233SZJWr58uRISErRq1SrdcccdXs8ZABB4yDVWwADA0VymYdkmSYWFhR5baWlplT/36quv1rp167Rnzx5J0s6dO7Vx40b16dNHknTw4EHl5OQoPT3d/T2xsbHq2rWrtmzZ4uO/FQBAbUWusQIGAHVKcnKyx9ePPvqoHnvssbPOmzhxogoLC9WmTRsFBwersrJSTz75pAYMGCBJysnJkSQlJCR4fF9CQoL7GAAAvlYbc40CDAAcrFJBqrSgWeHMGEeOHFFMTIx7f1hYWJXnv/rqq3r55Ze1YsUKXXbZZdqxY4dGjRqlpKQkZWZmej0fAEDdRK5RgAGAo/24zcLbcSQpJibGI6iqM27cOE2cONHd896+fXsdOnRIWVlZyszMVGJioiQpNzdXTZs2dX9fbm6uOnbs6PV8AQCBiVzjHjAAQBVOnTqloCDPiAgODpbL5ZIkpaamKjExUevWrXMfLyws1NatW9WtWzdb5woAwM9xUq6xAgYADuZSkFwWXCur6Ri33HKLnnzySTVr1kyXXXaZPvvsMz399NMaMmSIJMkwDI0aNUpPPPGELrnkEqWmpmry5MlKSkpS3759vZ4vACAwkWsUYADgaJWmoUoLWjVqOsZzzz2nyZMn64EHHtDx48eVlJSk3//+95oyZYr7nPHjx6u4uFjDhg1Tfn6+evTooTVr1ig8PNzr+QIAAhO5Jhnmj1//jCoVFhYqNjZW3+5poZhoujbhKSOpo7+nAIepMMu1Xn9VQUHBefWlV+XMvzv3f9RPYVH1vJ5TaVG55l3zhldzQuAg13Au5Bp+ily
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Регрессионная модель"
]
},
{
"cell_type": "code",
"execution_count": 148,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Pregnancies', 'Glucose', 'BloodPressure', 'SkinThickness', 'Insulin',\n",
" 'BMI', 'DiabetesPedigreeFunction', 'Age', 'Outcome'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6</td>\n",
" <td>148</td>\n",
" <td>72</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" <td>33.6</td>\n",
" <td>0.627</td>\n",
" <td>50</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>85</td>\n",
" <td>66</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>26.6</td>\n",
" <td>0.351</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>8</td>\n",
" <td>183</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>23.3</td>\n",
" <td>0.672</td>\n",
" <td>32</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>89</td>\n",
" <td>66</td>\n",
" <td>23</td>\n",
" <td>94</td>\n",
" <td>28.1</td>\n",
" <td>0.167</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>137</td>\n",
" <td>40</td>\n",
" <td>35</td>\n",
" <td>168</td>\n",
" <td>43.1</td>\n",
" <td>2.288</td>\n",
" <td>33</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>763</th>\n",
" <td>10</td>\n",
" <td>101</td>\n",
" <td>76</td>\n",
" <td>48</td>\n",
" <td>180</td>\n",
" <td>32.9</td>\n",
" <td>0.171</td>\n",
" <td>63</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>764</th>\n",
" <td>2</td>\n",
" <td>122</td>\n",
" <td>70</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" <td>36.8</td>\n",
" <td>0.340</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>765</th>\n",
" <td>5</td>\n",
" <td>121</td>\n",
" <td>72</td>\n",
" <td>23</td>\n",
" <td>112</td>\n",
" <td>26.2</td>\n",
" <td>0.245</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>766</th>\n",
" <td>1</td>\n",
" <td>126</td>\n",
" <td>60</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.1</td>\n",
" <td>0.349</td>\n",
" <td>47</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>767</th>\n",
" <td>1</td>\n",
" <td>93</td>\n",
" <td>70</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>30.4</td>\n",
" <td>0.315</td>\n",
" <td>23</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>768 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 148 72 35 0 33.6 \n",
"1 1 85 66 29 0 26.6 \n",
"2 8 183 64 0 0 23.3 \n",
"3 1 89 66 23 94 28.1 \n",
"4 0 137 40 35 168 43.1 \n",
".. ... ... ... ... ... ... \n",
"763 10 101 76 48 180 32.9 \n",
"764 2 122 70 27 0 36.8 \n",
"765 5 121 72 23 112 26.2 \n",
"766 1 126 60 0 0 30.1 \n",
"767 1 93 70 31 0 30.4 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"0 0.627 50 1 \n",
"1 0.351 31 0 \n",
"2 0.672 32 1 \n",
"3 0.167 21 0 \n",
"4 2.288 33 1 \n",
".. ... ... ... \n",
"763 0.171 63 0 \n",
"764 0.340 27 0 \n",
"765 0.245 30 0 \n",
"766 0.349 47 1 \n",
"767 0.315 23 0 \n",
"\n",
"[768 rows x 9 columns]"
]
},
"execution_count": 148,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"random_state=9\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\".//scv//diabetes.csv\")\n",
"print(df.columns)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Разделение набора данных на обучающую и тестовые выборки"
]
},
{
"cell_type": "code",
"execution_count": 150,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>60</th>\n",
" <td>2</td>\n",
" <td>84</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.304</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>618</th>\n",
" <td>9</td>\n",
" <td>112</td>\n",
" <td>82</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" <td>28.2</td>\n",
" <td>1.282</td>\n",
" <td>50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>346</th>\n",
" <td>1</td>\n",
" <td>139</td>\n",
" <td>46</td>\n",
" <td>19</td>\n",
" <td>83</td>\n",
" <td>28.7</td>\n",
" <td>0.654</td>\n",
" <td>22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>0</td>\n",
" <td>161</td>\n",
" <td>50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.9</td>\n",
" <td>0.254</td>\n",
" <td>65</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231</th>\n",
" <td>6</td>\n",
" <td>134</td>\n",
" <td>80</td>\n",
" <td>37</td>\n",
" <td>370</td>\n",
" <td>46.2</td>\n",
" <td>0.238</td>\n",
" <td>46</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>71</th>\n",
" <td>5</td>\n",
" <td>139</td>\n",
" <td>64</td>\n",
" <td>35</td>\n",
" <td>140</td>\n",
" <td>28.6</td>\n",
" <td>0.411</td>\n",
" <td>26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>106</th>\n",
" <td>1</td>\n",
" <td>96</td>\n",
" <td>122</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>22.4</td>\n",
" <td>0.207</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>270</th>\n",
" <td>10</td>\n",
" <td>101</td>\n",
" <td>86</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" <td>45.6</td>\n",
" <td>1.136</td>\n",
" <td>38</td>\n",
" </tr>\n",
" <tr>\n",
" <th>435</th>\n",
" <td>0</td>\n",
" <td>141</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>42.4</td>\n",
" <td>0.205</td>\n",
" <td>29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>102</th>\n",
" <td>0</td>\n",
" <td>125</td>\n",
" <td>96</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>22.5</td>\n",
" <td>0.262</td>\n",
" <td>21</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"60 2 84 0 0 0 0.0 \n",
"618 9 112 82 24 0 28.2 \n",
"346 1 139 46 19 83 28.7 \n",
"294 0 161 50 0 0 21.9 \n",
"231 6 134 80 37 370 46.2 \n",
".. ... ... ... ... ... ... \n",
"71 5 139 64 35 140 28.6 \n",
"106 1 96 122 0 0 22.4 \n",
"270 10 101 86 37 0 45.6 \n",
"435 0 141 0 0 0 42.4 \n",
"102 0 125 96 0 0 22.5 \n",
"\n",
" DiabetesPedigreeFunction Age \n",
"60 0.304 21 \n",
"618 1.282 50 \n",
"346 0.654 22 \n",
"294 0.254 65 \n",
"231 0.238 46 \n",
".. ... ... \n",
"71 0.411 26 \n",
"106 0.207 27 \n",
"270 1.136 38 \n",
"435 0.205 29 \n",
"102 0.262 21 \n",
"\n",
"[614 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>60</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>618</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>346</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>71</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>106</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>270</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>435</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>102</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"60 0\n",
"618 1\n",
"346 0\n",
"294 0\n",
"231 1\n",
".. ...\n",
"71 0\n",
"106 0\n",
"270 1\n",
"435 1\n",
"102 0\n",
"\n",
"[614 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>668</th>\n",
" <td>6</td>\n",
" <td>98</td>\n",
" <td>58</td>\n",
" <td>33</td>\n",
" <td>190</td>\n",
" <td>34.0</td>\n",
" <td>0.430</td>\n",
" <td>43</td>\n",
" </tr>\n",
" <tr>\n",
" <th>324</th>\n",
" <td>2</td>\n",
" <td>112</td>\n",
" <td>75</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.7</td>\n",
" <td>0.148</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>624</th>\n",
" <td>2</td>\n",
" <td>108</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.8</td>\n",
" <td>0.158</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>690</th>\n",
" <td>8</td>\n",
" <td>107</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.6</td>\n",
" <td>0.856</td>\n",
" <td>34</td>\n",
" </tr>\n",
" <tr>\n",
" <th>473</th>\n",
" <td>7</td>\n",
" <td>136</td>\n",
" <td>90</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29.9</td>\n",
" <td>0.210</td>\n",
" <td>50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>355</th>\n",
" <td>9</td>\n",
" <td>165</td>\n",
" <td>88</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.4</td>\n",
" <td>0.302</td>\n",
" <td>49</td>\n",
" </tr>\n",
" <tr>\n",
" <th>534</th>\n",
" <td>1</td>\n",
" <td>77</td>\n",
" <td>56</td>\n",
" <td>30</td>\n",
" <td>56</td>\n",
" <td>33.3</td>\n",
" <td>1.251</td>\n",
" <td>24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>344</th>\n",
" <td>8</td>\n",
" <td>95</td>\n",
" <td>72</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>36.8</td>\n",
" <td>0.485</td>\n",
" <td>57</td>\n",
" </tr>\n",
" <tr>\n",
" <th>296</th>\n",
" <td>2</td>\n",
" <td>146</td>\n",
" <td>70</td>\n",
" <td>38</td>\n",
" <td>360</td>\n",
" <td>28.0</td>\n",
" <td>0.337</td>\n",
" <td>29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>462</th>\n",
" <td>8</td>\n",
" <td>74</td>\n",
" <td>70</td>\n",
" <td>40</td>\n",
" <td>49</td>\n",
" <td>35.3</td>\n",
" <td>0.705</td>\n",
" <td>39</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"668 6 98 58 33 190 34.0 \n",
"324 2 112 75 32 0 35.7 \n",
"624 2 108 64 0 0 30.8 \n",
"690 8 107 80 0 0 24.6 \n",
"473 7 136 90 0 0 29.9 \n",
".. ... ... ... ... ... ... \n",
"355 9 165 88 0 0 30.4 \n",
"534 1 77 56 30 56 33.3 \n",
"344 8 95 72 0 0 36.8 \n",
"296 2 146 70 38 360 28.0 \n",
"462 8 74 70 40 49 35.3 \n",
"\n",
" DiabetesPedigreeFunction Age \n",
"668 0.430 43 \n",
"324 0.148 21 \n",
"624 0.158 21 \n",
"690 0.856 34 \n",
"473 0.210 50 \n",
".. ... ... \n",
"355 0.302 49 \n",
"534 1.251 24 \n",
"344 0.485 57 \n",
"296 0.337 29 \n",
"462 0.705 39 \n",
"\n",
"[154 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>668</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>324</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>624</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>690</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>473</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>355</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>534</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>344</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>296</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>462</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"668 0\n",
"324 0\n",
"624 0\n",
"690 0\n",
"473 0\n",
".. ...\n",
"355 1\n",
"534 0\n",
"344 0\n",
"296 1\n",
"462 0\n",
"\n",
"[154 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" target_colname: str = \"above_average_close\",\n",
" frac_train: float = 0.8,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" # Проверка наличия целевого признака\n",
" if target_colname not in df_input.columns:\n",
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
" \n",
" # Разделяем данные на признаки и целевую переменную\n",
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
" y = df_input[[target_colname]] # Целевая переменная\n",
"\n",
" # Разделяем данные на обучающую и тестовую выборки\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" \n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"# Применение функции для разделения данных\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
" df, \n",
" target_colname=\"Outcome\", \n",
" frac_train=0.8, \n",
" random_state=42 # Убедитесь, что вы задали нужное значение random_state\n",
")\n",
"\n",
"# Для отображения результатов\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 168,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"random_state = 9\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import math\n",
"from pandas import DataFrame\n",
"from sklearn import metrics\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
"\n",
" fitted_model = models[model_name][\"model\"].fit(\n",
" X_train.values, y_train.values.ravel()\n",
" )\n",
" y_train_pred = fitted_model.predict(X_train.values)\n",
" y_test_pred = fitted_model.predict(X_test.values)\n",
" models[model_name][\"fitted\"] = fitted_model\n",
" models[model_name][\"train_preds\"] = y_train_pred\n",
" models[model_name][\"preds\"] = y_test_pred\n",
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_train, y_train_pred)\n",
" )\n",
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод результатов оценки"
]
},
{
"cell_type": "code",
"execution_count": 170,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_6040e_row0_col0, #T_6040e_row0_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row0_col2 {\n",
" background-color: #7e03a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row0_col3, #T_6040e_row7_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row1_col0, #T_6040e_row2_col0 {\n",
" background-color: #25ab82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row1_col1 {\n",
" background-color: #24868e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row1_col2, #T_6040e_row2_col2 {\n",
" background-color: #a11b9b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row1_col3 {\n",
" background-color: #d5546e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row2_col1 {\n",
" background-color: #24878e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row2_col3 {\n",
" background-color: #d5536f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row3_col0, #T_6040e_row6_col0 {\n",
" background-color: #20a486;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row3_col1 {\n",
" background-color: #228d8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row3_col2 {\n",
" background-color: #9a169f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row3_col3 {\n",
" background-color: #cf4c74;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row4_col0 {\n",
" background-color: #21a685;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row4_col1 {\n",
" background-color: #21918c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row4_col2 {\n",
" background-color: #a51f99;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row4_col3 {\n",
" background-color: #cc4977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row5_col0 {\n",
" background-color: #25838e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row5_col1 {\n",
" background-color: #1f9f88;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row5_col2, #T_6040e_row7_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row5_col3 {\n",
" background-color: #bf3984;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row6_col1 {\n",
" background-color: #1fa287;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row6_col2 {\n",
" background-color: #a31e9a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row6_col3 {\n",
" background-color: #bc3587;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6040e_row7_col0, #T_6040e_row7_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_6040e\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_6040e_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_6040e_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_6040e_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_6040e_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_6040e_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_6040e_row0_col0\" class=\"data row0 col0\" >0.240052</td>\n",
" <td id=\"T_6040e_row0_col1\" class=\"data row0 col1\" >0.405871</td>\n",
" <td id=\"T_6040e_row0_col2\" class=\"data row0 col2\" >0.559210</td>\n",
" <td id=\"T_6040e_row0_col3\" class=\"data row0 col3\" >0.282505</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6040e_level0_row1\" class=\"row_heading level0 row1\" >linear</th>\n",
" <td id=\"T_6040e_row1_col0\" class=\"data row1 col0\" >0.396793</td>\n",
" <td id=\"T_6040e_row1_col1\" class=\"data row1 col1\" >0.413576</td>\n",
" <td id=\"T_6040e_row1_col2\" class=\"data row1 col2\" >0.590024</td>\n",
" <td id=\"T_6040e_row1_col3\" class=\"data row1 col3\" >0.255003</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6040e_level0_row2\" class=\"row_heading level0 row2\" >ridge</th>\n",
" <td id=\"T_6040e_row2_col0\" class=\"data row2 col0\" >0.396822</td>\n",
" <td id=\"T_6040e_row2_col1\" class=\"data row2 col1\" >0.414236</td>\n",
" <td id=\"T_6040e_row2_col2\" class=\"data row2 col2\" >0.590431</td>\n",
" <td id=\"T_6040e_row2_col3\" class=\"data row2 col3\" >0.252623</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6040e_level0_row3\" class=\"row_heading level0 row3\" >linear_poly</th>\n",
" <td id=\"T_6040e_row3_col0\" class=\"data row3 col0\" >0.370076</td>\n",
" <td id=\"T_6040e_row3_col1\" class=\"data row3 col1\" >0.422852</td>\n",
" <td id=\"T_6040e_row3_col2\" class=\"data row3 col2\" >0.584147</td>\n",
" <td id=\"T_6040e_row3_col3\" class=\"data row3 col3\" >0.221209</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6040e_level0_row4\" class=\"row_heading level0 row4\" >linear_interact</th>\n",
" <td id=\"T_6040e_row4_col0\" class=\"data row4 col0\" >0.380128</td>\n",
" <td id=\"T_6040e_row4_col1\" class=\"data row4 col1\" >0.426815</td>\n",
" <td id=\"T_6040e_row4_col2\" class=\"data row4 col2\" >0.593532</td>\n",
" <td id=\"T_6040e_row4_col3\" class=\"data row4 col3\" >0.206543</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6040e_level0_row5\" class=\"row_heading level0 row5\" >decision_tree</th>\n",
" <td id=\"T_6040e_row5_col0\" class=\"data row5 col0\" >0.249880</td>\n",
" <td id=\"T_6040e_row5_col1\" class=\"data row5 col1\" >0.445708</td>\n",
" <td id=\"T_6040e_row5_col2\" class=\"data row5 col2\" >0.520376</td>\n",
" <td id=\"T_6040e_row5_col3\" class=\"data row5 col3\" >0.134743</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6040e_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_6040e_row6_col0\" class=\"data row6 col0\" >0.373319</td>\n",
" <td id=\"T_6040e_row6_col1\" class=\"data row6 col1\" >0.450285</td>\n",
" <td id=\"T_6040e_row6_col2\" class=\"data row6 col2\" >0.592157</td>\n",
" <td id=\"T_6040e_row6_col3\" class=\"data row6 col3\" >0.116883</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6040e_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_6040e_row7_col0\" class=\"data row7 col0\" >0.623529</td>\n",
" <td id=\"T_6040e_row7_col1\" class=\"data row7 col1\" >0.544323</td>\n",
" <td id=\"T_6040e_row7_col2\" class=\"data row7 col2\" >0.658689</td>\n",
" <td id=\"T_6040e_row7_col3\" class=\"data row7 col3\" >-0.290498</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x175be7f0a70>"
]
},
"execution_count": 170,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок\n",
"\n",
"Получение лучшей модели\n"
]
},
{
"cell_type": "code",
"execution_count": 171,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'random_forest'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод для обучающей выборки"
]
},
{
"cell_type": "code",
"execution_count": 173,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" <th>DiabetPred</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>60</th>\n",
" <td>2</td>\n",
" <td>84</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.304</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" <td>0.001849</td>\n",
" </tr>\n",
" <tr>\n",
" <th>618</th>\n",
" <td>9</td>\n",
" <td>112</td>\n",
" <td>82</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" <td>28.2</td>\n",
" <td>1.282</td>\n",
" <td>50</td>\n",
" <td>1</td>\n",
" <td>0.758997</td>\n",
" </tr>\n",
" <tr>\n",
" <th>346</th>\n",
" <td>1</td>\n",
" <td>139</td>\n",
" <td>46</td>\n",
" <td>19</td>\n",
" <td>83</td>\n",
" <td>28.7</td>\n",
" <td>0.654</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" <td>0.149231</td>\n",
" </tr>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>0</td>\n",
" <td>161</td>\n",
" <td>50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.9</td>\n",
" <td>0.254</td>\n",
" <td>65</td>\n",
" <td>0</td>\n",
" <td>0.239564</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231</th>\n",
" <td>6</td>\n",
" <td>134</td>\n",
" <td>80</td>\n",
" <td>37</td>\n",
" <td>370</td>\n",
" <td>46.2</td>\n",
" <td>0.238</td>\n",
" <td>46</td>\n",
" <td>1</td>\n",
" <td>0.773890</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"60 2 84 0 0 0 0.0 \n",
"618 9 112 82 24 0 28.2 \n",
"346 1 139 46 19 83 28.7 \n",
"294 0 161 50 0 0 21.9 \n",
"231 6 134 80 37 370 46.2 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome DiabetPred \n",
"60 0.304 21 0 0.001849 \n",
"618 1.282 50 1 0.758997 \n",
"346 0.654 22 0 0.149231 \n",
"294 0.254 65 0 0.239564 \n",
"231 0.238 46 1 0.773890 "
]
},
"execution_count": 173,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.concat(\n",
" [\n",
" X_train,\n",
" y_train,\n",
" pd.Series(\n",
" models[best_model][\"train_preds\"],\n",
" index=y_train.index,\n",
" name=\"DiabetPred\",\n",
" ),\n",
" ],\n",
" axis=1,\n",
").head(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод для тестовой выборки"
]
},
{
"cell_type": "code",
"execution_count": 174,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" <th>DiabetPred</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>668</th>\n",
" <td>6</td>\n",
" <td>98</td>\n",
" <td>58</td>\n",
" <td>33</td>\n",
" <td>190</td>\n",
" <td>34.0</td>\n",
" <td>0.430</td>\n",
" <td>43</td>\n",
" <td>0</td>\n",
" <td>0.516537</td>\n",
" </tr>\n",
" <tr>\n",
" <th>324</th>\n",
" <td>2</td>\n",
" <td>112</td>\n",
" <td>75</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.7</td>\n",
" <td>0.148</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" <td>0.205507</td>\n",
" </tr>\n",
" <tr>\n",
" <th>624</th>\n",
" <td>2</td>\n",
" <td>108</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.8</td>\n",
" <td>0.158</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" <td>0.047710</td>\n",
" </tr>\n",
" <tr>\n",
" <th>690</th>\n",
" <td>8</td>\n",
" <td>107</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.6</td>\n",
" <td>0.856</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" <td>0.128867</td>\n",
" </tr>\n",
" <tr>\n",
" <th>473</th>\n",
" <td>7</td>\n",
" <td>136</td>\n",
" <td>90</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29.9</td>\n",
" <td>0.210</td>\n",
" <td>50</td>\n",
" <td>0</td>\n",
" <td>0.438512</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"668 6 98 58 33 190 34.0 \n",
"324 2 112 75 32 0 35.7 \n",
"624 2 108 64 0 0 30.8 \n",
"690 8 107 80 0 0 24.6 \n",
"473 7 136 90 0 0 29.9 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome DiabetPred \n",
"668 0.430 43 0 0.516537 \n",
"324 0.148 21 0 0.205507 \n",
"624 0.158 21 0 0.047710 \n",
"690 0.856 34 0 0.128867 \n",
"473 0.210 50 0 0.438512 "
]
},
"execution_count": 174,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.concat(\n",
" [\n",
" X_test,\n",
" y_test,\n",
" pd.Series(\n",
" models[best_model][\"preds\"],\n",
" index=y_test.index,\n",
" name=\"DiabetPred\",\n",
" ),\n",
" ],\n",
" axis=1,\n",
").head(5)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}