AIM-PIbd-31-Razubaev-S-M/Lab4/lab4.ipynb

7581 lines
844 KiB
Plaintext
Raw Normal View History

2024-11-08 22:14:23 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Лабораторная 4"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Информация о диабете индейцев Пима"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 267,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Pregnancies', 'Glucose', 'BloodPressure', 'SkinThickness', 'Insulin',\n",
" 'BMI', 'DiabetesPedigreeFunction', 'Age', 'Outcome'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6</td>\n",
" <td>148</td>\n",
" <td>72</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" <td>33.6</td>\n",
" <td>0.627</td>\n",
" <td>50</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>85</td>\n",
" <td>66</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>26.6</td>\n",
" <td>0.351</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>8</td>\n",
" <td>183</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>23.3</td>\n",
" <td>0.672</td>\n",
" <td>32</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>89</td>\n",
" <td>66</td>\n",
" <td>23</td>\n",
" <td>94</td>\n",
" <td>28.1</td>\n",
" <td>0.167</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>137</td>\n",
" <td>40</td>\n",
" <td>35</td>\n",
" <td>168</td>\n",
" <td>43.1</td>\n",
" <td>2.288</td>\n",
" <td>33</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>763</th>\n",
" <td>10</td>\n",
" <td>101</td>\n",
" <td>76</td>\n",
" <td>48</td>\n",
" <td>180</td>\n",
" <td>32.9</td>\n",
" <td>0.171</td>\n",
" <td>63</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>764</th>\n",
" <td>2</td>\n",
" <td>122</td>\n",
" <td>70</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" <td>36.8</td>\n",
" <td>0.340</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>765</th>\n",
" <td>5</td>\n",
" <td>121</td>\n",
" <td>72</td>\n",
" <td>23</td>\n",
" <td>112</td>\n",
" <td>26.2</td>\n",
" <td>0.245</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>766</th>\n",
" <td>1</td>\n",
" <td>126</td>\n",
" <td>60</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.1</td>\n",
" <td>0.349</td>\n",
" <td>47</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>767</th>\n",
" <td>1</td>\n",
" <td>93</td>\n",
" <td>70</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>30.4</td>\n",
" <td>0.315</td>\n",
" <td>23</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>768 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 148 72 35 0 33.6 \n",
"1 1 85 66 29 0 26.6 \n",
"2 8 183 64 0 0 23.3 \n",
"3 1 89 66 23 94 28.1 \n",
"4 0 137 40 35 168 43.1 \n",
".. ... ... ... ... ... ... \n",
"763 10 101 76 48 180 32.9 \n",
"764 2 122 70 27 0 36.8 \n",
"765 5 121 72 23 112 26.2 \n",
"766 1 126 60 0 0 30.1 \n",
"767 1 93 70 31 0 30.4 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"0 0.627 50 1 \n",
"1 0.351 31 0 \n",
"2 0.672 32 1 \n",
"3 0.167 21 0 \n",
"4 2.288 33 1 \n",
".. ... ... ... \n",
"763 0.171 63 0 \n",
"764 0.340 27 0 \n",
"765 0.245 30 0 \n",
"766 0.349 47 1 \n",
"767 0.315 23 0 \n",
"\n",
"[768 rows x 9 columns]"
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 267,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from sklearn import set_config\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Make scikit-learn transformers return pandas DataFrames so that column\n",
"# names survive preprocessing.\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Pima Indians diabetes dataset: 8 features plus the binary `Outcome` target.\n",
"DATA_PATH = \".//scv//diabetes.csv\"\n",
"df = pd.read_csv(DATA_PATH)\n",
"print(df.columns)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование выборок"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 268,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>1</td>\n",
" <td>105</td>\n",
" <td>58</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.3</td>\n",
" <td>0.187</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>4</td>\n",
" <td>146</td>\n",
" <td>85</td>\n",
" <td>27</td>\n",
" <td>100</td>\n",
" <td>28.9</td>\n",
" <td>0.189</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>3</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.174</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>5</td>\n",
" <td>88</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" <td>27.6</td>\n",
" <td>0.258</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>2</td>\n",
" <td>120</td>\n",
" <td>54</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>26.8</td>\n",
" <td>0.455</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>0</td>\n",
" <td>124</td>\n",
" <td>70</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.254</td>\n",
" <td>36</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>0</td>\n",
" <td>95</td>\n",
" <td>85</td>\n",
" <td>25</td>\n",
" <td>36</td>\n",
" <td>37.4</td>\n",
" <td>0.247</td>\n",
" <td>24</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>1</td>\n",
" <td>97</td>\n",
" <td>66</td>\n",
" <td>15</td>\n",
" <td>140</td>\n",
" <td>23.2</td>\n",
" <td>0.487</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>1</td>\n",
" <td>117</td>\n",
" <td>60</td>\n",
" <td>23</td>\n",
" <td>106</td>\n",
" <td>33.8</td>\n",
" <td>0.466</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>3</td>\n",
" <td>107</td>\n",
" <td>62</td>\n",
" <td>13</td>\n",
" <td>48</td>\n",
" <td>22.9</td>\n",
" <td>0.678</td>\n",
" <td>23</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"196 1 105 58 0 0 24.3 \n",
"69 4 146 85 27 100 28.9 \n",
"494 3 80 0 0 0 0.0 \n",
"463 5 88 78 30 0 27.6 \n",
"653 2 120 54 0 0 26.8 \n",
".. ... ... ... ... ... ... \n",
"322 0 124 70 20 0 27.4 \n",
"109 0 95 85 25 36 37.4 \n",
"27 1 97 66 15 140 23.2 \n",
"651 1 117 60 23 106 33.8 \n",
"197 3 107 62 13 48 22.9 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"196 0.187 21 0 \n",
"69 0.189 27 0 \n",
"494 0.174 22 0 \n",
"463 0.258 37 0 \n",
"653 0.455 27 0 \n",
".. ... ... ... \n",
"322 0.254 36 1 \n",
"109 0.247 24 1 \n",
"27 0.487 22 0 \n",
"651 0.466 27 0 \n",
"197 0.678 23 1 \n",
"\n",
"[614 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"196 0\n",
"69 0\n",
"494 0\n",
"463 0\n",
"653 0\n",
".. ...\n",
"322 1\n",
"109 1\n",
"27 0\n",
"651 0\n",
"197 1\n",
"\n",
"[614 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>9</td>\n",
" <td>154</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>30.9</td>\n",
" <td>0.164</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>379</th>\n",
" <td>0</td>\n",
" <td>93</td>\n",
" <td>100</td>\n",
" <td>39</td>\n",
" <td>72</td>\n",
" <td>43.4</td>\n",
" <td>1.021</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>640</th>\n",
" <td>0</td>\n",
" <td>102</td>\n",
" <td>86</td>\n",
" <td>17</td>\n",
" <td>105</td>\n",
" <td>29.3</td>\n",
" <td>0.695</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>11</td>\n",
" <td>127</td>\n",
" <td>106</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>39.0</td>\n",
" <td>0.190</td>\n",
" <td>51</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>304</th>\n",
" <td>3</td>\n",
" <td>150</td>\n",
" <td>76</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.0</td>\n",
" <td>0.207</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203</th>\n",
" <td>2</td>\n",
" <td>99</td>\n",
" <td>70</td>\n",
" <td>16</td>\n",
" <td>44</td>\n",
" <td>20.4</td>\n",
" <td>0.235</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>605</th>\n",
" <td>1</td>\n",
" <td>124</td>\n",
" <td>60</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.8</td>\n",
" <td>0.514</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>561</th>\n",
" <td>0</td>\n",
" <td>198</td>\n",
" <td>66</td>\n",
" <td>32</td>\n",
" <td>274</td>\n",
" <td>41.3</td>\n",
" <td>0.502</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>280</th>\n",
" <td>0</td>\n",
" <td>146</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>37.9</td>\n",
" <td>0.334</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103</th>\n",
" <td>1</td>\n",
" <td>81</td>\n",
" <td>72</td>\n",
" <td>18</td>\n",
" <td>40</td>\n",
" <td>26.6</td>\n",
" <td>0.283</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"669 9 154 78 30 100 30.9 \n",
"379 0 93 100 39 72 43.4 \n",
"640 0 102 86 17 105 29.3 \n",
"658 11 127 106 0 0 39.0 \n",
"304 3 150 76 0 0 21.0 \n",
".. ... ... ... ... ... ... \n",
"203 2 99 70 16 44 20.4 \n",
"605 1 124 60 32 0 35.8 \n",
"561 0 198 66 32 274 41.3 \n",
"280 0 146 70 0 0 37.9 \n",
"103 1 81 72 18 40 26.6 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"669 0.164 45 0 \n",
"379 1.021 35 0 \n",
"640 0.695 27 0 \n",
"658 0.190 51 0 \n",
"304 0.207 37 0 \n",
".. ... ... ... \n",
"203 0.235 27 0 \n",
"605 0.514 21 0 \n",
"561 0.502 28 1 \n",
"280 0.334 28 1 \n",
"103 0.283 24 0 \n",
"\n",
"[154 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>379</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>640</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>304</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>605</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>561</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>280</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"669 0\n",
"379 0\n",
"640 0\n",
"658 0\n",
"304 0\n",
".. ...\n",
"203 0\n",
"605 0\n",
"561 1\n",
"280 1\n",
"103 0\n",
"\n",
"[154 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input,\n",
"    stratify_colname=\"y\",\n",
"    frac_train=0.6,\n",
"    frac_val=0.15,\n",
"    frac_test=0.25,\n",
"    random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"Split a DataFrame into stratified train / validation / test parts.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df_input : DataFrame\n",
"        Source data. All of its columns, including ``stratify_colname``,\n",
"        are kept in the returned feature frames, so the caller has to drop\n",
"        the target column before modelling.\n",
"    stratify_colname : str\n",
"        Column whose class proportions are preserved in every split.\n",
"    frac_train, frac_val, frac_test : float\n",
"        Share of rows for each split. They must sum to 1.0 (checked with a\n",
"        small tolerance to absorb floating-point rounding). When\n",
"        ``frac_val <= 0`` no validation split is produced.\n",
"    random_state : int or None\n",
"        Seed forwarded to ``train_test_split`` for reproducible splits.\n",
"\n",
"    Returns\n",
"    -------\n",
"    tuple of DataFrame\n",
"        ``(df_train, df_val, df_test, y_train, y_val, y_test)``; ``df_val``\n",
"        and ``y_val`` are empty DataFrames when ``frac_val <= 0``.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If the fractions do not add up to 1.0 or ``stratify_colname`` is not\n",
"        a column of ``df_input``.\n",
"    \"\"\"\n",
"    # An exact `!= 1.0` comparison rejects valid inputs such as\n",
"    # 0.7 + 0.1 + 0.2, which evaluates to 0.9999999999999999.\n",
"    if abs(frac_train + frac_val + frac_test - 1.0) > 1e-9:\n",
"        raise ValueError(\n",
"            \"fractions %f, %f, %f do not add up to 1.0\"\n",
"            % (frac_train, frac_val, frac_test)\n",
"        )\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"\n",
"    X = df_input  # All columns, including the stratify/target column.\n",
"    y = df_input[[stratify_colname]]  # One-column frame used for stratification.\n",
"\n",
"    # First split: train vs. the remainder.\n",
"    df_train, df_temp, y_train, y_temp = train_test_split(\n",
"        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
"    )\n",
"    if frac_val <= 0:\n",
"        assert len(df_input) == len(df_train) + len(df_temp)\n",
"        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
"\n",
"    # Second split: divide the remainder into validation and test parts.\n",
"    relative_frac_test = frac_test / (frac_val + frac_test)\n",
"    df_val, df_test, y_val, y_test = train_test_split(\n",
"        df_temp,\n",
"        y_temp,\n",
"        stratify=y_temp,\n",
"        test_size=relative_frac_test,\n",
"        random_state=random_state,\n",
"    )\n",
"    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"    return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"# Stratified 80/20 split on `Outcome` with no validation part (frac_val=0).\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
"    df,\n",
"    stratify_colname=\"Outcome\",\n",
"    frac_train=0.80,\n",
"    frac_val=0,\n",
"    frac_test=0.20,\n",
"    random_state=9,\n",
")\n",
"\n",
"for split_name, split_frame in (\n",
"    (\"X_train\", X_train),\n",
"    (\"y_train\", y_train),\n",
"    (\"X_test\", X_test),\n",
"    (\"y_test\", y_test),\n",
"):\n",
"    display(split_name, split_frame)"
]
},
2024-11-09 10:46:39 +04:00
{
"cell_type": "code",
"execution_count": 269,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Пропущенные значения по столбцам:\n",
"Pregnancies 0\n",
"Glucose 0\n",
"BloodPressure 0\n",
"SkinThickness 0\n",
"Insulin 0\n",
"BMI 0\n",
"DiabetesPedigreeFunction 0\n",
"Age 0\n",
"Outcome 0\n",
"dtype: int64\n",
"\n",
"Статистический обзор данных:\n",
" Pregnancies Glucose BloodPressure SkinThickness Insulin \\\n",
"count 768.000000 768.000000 768.000000 768.000000 768.000000 \n",
"mean 3.845052 120.894531 69.105469 20.536458 79.799479 \n",
"std 3.369578 31.972618 19.355807 15.952218 115.244002 \n",
"min 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
"25% 1.000000 99.000000 62.000000 0.000000 0.000000 \n",
"50% 3.000000 117.000000 72.000000 23.000000 30.500000 \n",
"75% 6.000000 140.250000 80.000000 32.000000 127.250000 \n",
"max 17.000000 199.000000 122.000000 99.000000 846.000000 \n",
"\n",
" BMI DiabetesPedigreeFunction Age Outcome \n",
"count 768.000000 768.000000 768.000000 768.000000 \n",
"mean 31.992578 0.471876 33.240885 0.348958 \n",
"std 7.884160 0.331329 11.760232 0.476951 \n",
"min 0.000000 0.078000 21.000000 0.000000 \n",
"25% 27.300000 0.243750 24.000000 0.000000 \n",
"50% 32.000000 0.372500 29.000000 0.000000 \n",
"75% 36.600000 0.626250 41.000000 1.000000 \n",
"max 67.100000 2.420000 81.000000 1.000000 \n"
]
}
],
"source": [
"# Missing values per column.\n",
"null_values = df.isna().sum()\n",
"# NOTE(review): isna() finds no gaps, but the summary shows min == 0 for\n",
"# Glucose, BloodPressure, SkinThickness, Insulin and BMI, which likely encodes\n",
"# missing measurements; confirm before relying on these columns.\n",
"# Descriptive statistics for every numeric column.\n",
"stat_summary = df.describe()\n",
"\n",
"print(\"Пропущенные значения по столбцам:\")\n",
"print(null_values)\n",
"print(\"\\nСтатистический обзор данных:\")\n",
"print(stat_summary)"
]
},
{
"cell_type": "code",
"execution_count": 270,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Выбросы в датасете:\n",
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"4 0 137 40 35 168 43.1 \n",
"12 10 139 80 0 0 27.1 \n",
"39 4 111 72 47 207 37.1 \n",
"45 0 180 66 39 0 42.0 \n",
"58 0 146 82 0 0 40.5 \n",
"100 1 163 72 0 0 39.0 \n",
"147 2 106 64 35 119 30.5 \n",
"187 1 128 98 41 58 32.0 \n",
"218 5 85 74 22 0 29.0 \n",
"228 4 197 70 39 744 36.7 \n",
"243 6 119 50 22 176 27.1 \n",
"245 9 184 85 15 0 30.0 \n",
"259 11 155 76 28 150 33.3 \n",
"292 2 128 78 37 182 43.3 \n",
"308 0 128 68 19 180 30.5 \n",
"330 8 118 72 19 0 23.1 \n",
"370 3 173 82 48 465 38.4 \n",
"371 0 118 64 23 89 0.0 \n",
"383 1 90 62 18 59 25.1 \n",
"395 2 127 58 24 275 27.7 \n",
"445 0 180 78 63 14 59.4 \n",
"534 1 77 56 30 56 33.3 \n",
"593 2 82 52 22 115 28.5 \n",
"606 1 181 78 42 293 40.0 \n",
"618 9 112 82 24 0 28.2 \n",
"621 2 92 76 20 0 24.2 \n",
"622 6 183 94 0 0 40.8 \n",
"659 3 80 82 31 70 34.2 \n",
"661 1 199 76 43 0 42.9 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"4 2.288 33 1 \n",
"12 1.441 57 0 \n",
"39 1.390 56 1 \n",
"45 1.893 25 1 \n",
"58 1.781 44 0 \n",
"100 1.222 33 1 \n",
"147 1.400 34 0 \n",
"187 1.321 33 1 \n",
"218 1.224 32 1 \n",
"228 2.329 31 0 \n",
"243 1.318 33 1 \n",
"245 1.213 49 1 \n",
"259 1.353 51 1 \n",
"292 1.224 31 1 \n",
"308 1.391 25 1 \n",
"330 1.476 46 0 \n",
"370 2.137 25 1 \n",
"371 1.731 21 0 \n",
"383 1.268 25 0 \n",
"395 1.600 25 0 \n",
"445 2.420 25 1 \n",
"534 1.251 24 0 \n",
"593 1.699 25 0 \n",
"606 1.258 22 1 \n",
"618 1.282 50 1 \n",
"621 1.698 28 0 \n",
"622 1.461 45 0 \n",
"659 1.292 27 1 \n",
"661 1.394 22 1 \n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAIjCAYAAADWYVDIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC5fUlEQVR4nOzdeXwU5eEG8Gc3dwK5OJKgGAKegILgQUBAKxbUQr1K64k3UtoqntAWEa3Gs+qvRRRUsCLiUS8spSqgeAS0IFYEFGNA1AQkIQkk5CA7vz/CrrubmZ3rnWv3+X4+fJQ93nmvmd2XnXnGJ0mSBCIiIiIiIgIA+J2uABERERERkZtwkURERERERBSGiyQiIiIiIqIwXCQRERERERGF4SKJiIiIiIgoDBdJREREREREYbhIIiIiIiIiCsNFEhERERERURgukoiIiIiIiMJwkURERERERBSGiyQissTLL78Mn88n+2fgwIFOV4+IiIhIUbLTFSCi+PbHP/4RxxxzTOjvd999t4O1ISIiIlLHRRIRWeqMM87AqaeeGvr7k08+id27dztXISIiIiIVPN2OiCzR2toKAPD71Q8zCxcuhM/nw7Zt20KPBQIBHHfccfD5fFi4cGHo8f/973+4/PLL0bdvX6Snp6OwsBBXXnklampqIsq84447ZE/1S07+6d+GTj31VAwcOBDr1q3D8OHDkZGRgZKSEjz++OOd2nL77bdj6NChyMnJQVZWFkaOHIlVq1ZFvG7btm2h7bz22msRzzU3NyMvLw8+nw8PPvhgp3r27NkTbW1tEe95/vnnQ+WFLyxff/11nH322ejVqxfS0tLQr18/3HXXXWhvb1ft6+D2tmzZgokTJyI7OxvdunXD9ddfj+bm5ojXLliwAD/72c/Qs2dPpKWloX///pg7d65suf/+978xevRodO3aFdnZ2TjxxBOxePHiiNesXbsWZ511FvLy8pCVlYXjjjsOjz76aMRrtmzZggsuuAD5+flIT0/HCSecgDfeeCPiNXrmy+WXXx4x/nl5eTj11FPx/vvvR5SptU+Dcybagw8+2KlOffr0weWXXx7xupdeegk+nw99+vSJeHzXrl246qqrcNhhhyEpKSlU3y5dunTaVrQ+ffoontrq8/kiXnvgwAHcdddd6NevH9LS0tCnTx/88Y9/REtLS6dytYxp+JyPtd1AIIBHHnkEAwYMQHp6OgoKCjB58mTs2bNHU/ui+/Hdd9+Fz+fDu+++G3rs1FNPjfgHGQD45JNPZOsDAIsWLcJJJ52EzMxM5OXlYdSoUXjrrbdC24zVp8HxC7Y/fM7t3bsXQ4cORUlJCaqqqhRfBwBTp06Fz+fr1D4ich5/SSIiSwQXSWlpaYbe/+yzz+Lzzz/v9Pjbb7+Nb775BldccQUKCwvxxRdfYN68efjiiy+wZs2aTl+G5s6dG/FFM3rRtmfPHpx11lmYOHEiLrzwQrz44ouYMmUKUlNTceWVVwIAGhoa8OSTT+LCCy/ENddcg7179+Kpp57C2LFj8fHHH2Pw4MERZaanp2PBggU455xzQo+98sornRYh4fbu3Ys333wT5557buixBQsWID09vdP7Fi5ciC5duuDGG29Ely5dsHLlStx+++1oaGjAAw88oLiNcBMnTkSfPn1QVlaGNWvW4P/+7/+wZ88e/OMf/4jouwEDBmDChAlITk7G0qVL8dvf/haBQABTp06NqM+VV16JAQMGYMaMGcjNzcWnn36K5cuX46KLLgLQMW6/+MUvUFRUhOuvvx6FhYXYvHkz3nzzTVx//fUAgC+++AIjRozAIYccgunTpyMrKwsvvvgizjnnHPzzn/+M6JtoSvMFALp3746HH34YAPDdd9/h0UcfxVlnnYUdO3YgNzdXWJ+qOXDgAP70pz/JPjdp0iS88847+P3vf49BgwYhKSkJ8+bNw/r16zWVPXjwYNx0000Rj/3jH//A22+/HfHY1VdfjWeeeQ
YXXHABbrrpJqxduxZlZWXYvHkzXn311dDrtIxpuGuvvRYjR44E0DHXw8sCgMmTJ2PhwoW44oor8Ic//AGVlZX4+9//jk8//RQffvghUlJSNLVTr9tuu0328dmzZ+OOO+7A8OHDceeddyI1NRVr167FypUr8fOf/xyPPPII9u3bBwDYvHkz7rnnnohTh5UWr21tbTj//PPx7bff4sMPP0RRUZFi3b7++mvMnz/fZAuJyDISEZEFHnnkEQmA9Nlnn0U8Pnr0aGnAgAERjy1YsEACIFVWVkqSJEnNzc3SYYcdJp155pkSAGnBggWh1zY1NXXa1vPPPy8BkFavXh16bNasWRIA6ccff1Ss4+jRoyUA0kMPPRR6rKWlRRo8eLDUs2dPqbW1VZIkSTpw4IDU0tIS8d49e/ZIBQUF0pVXXhl6rLKyUgIgXXjhhVJycrJUXV0deu7000+XLrroIgmA9MADD3Sq54UXXij94he/CD2+fft2ye/3SxdeeGGndsj1weTJk6XMzEypublZsb3h25swYULE47/97W87jZfcdsaOHSv17ds39Pe6ujqpa9eu0sknnyzt378/4rWBQECSpI7+KykpkYqLi6U9e/bIvkaSOvro2GOPjWhDIBCQhg8fLh1xxBGhx/TMl0mTJknFxcUR25w3b54EQPr4449jtlWuT+XmryRJ0gMPPBBRJ0mSpOLiYmnSpEmhvz/22GNSWlqadNppp0XUaf/+/ZLf75cmT54cUeakSZOkrKysTtuKVlxcLJ199tmdHp86daoU/jG/YcMGCYB09dVXR7zu5ptvlgBIK1eulCRJ25gGbd26VQIgPfPMM6HHgnMs6P3335cASM8991zEe5cvXy77eLSSkhLpsssui3hs1apVEgBp1apVocdGjx4tjR49OvT3ZcuWSQCkcePGRdRn69atkt/vl84991ypvb09ZvuUthUU3OcXLFggBQIB6eKLL5YyMzOltWvXKr4uaOLEidLAgQOl3r17R8wTInIHnm5HRJYInv7Wo0cP3e+dM2cOampqMGvWrE7PZWRkhP6/ubkZu3fvxrBhwwBA87+6h0tOTsbkyZNDf09NTcXkyZOxa9curFu3DgCQlJSE1NRUAB2nDdXW1uLAgQM44YQTZLc5ZMgQDBgwAM8++ywAYPv27Vi1alXMU2quvPJKLF++HNXV1QCAZ555BqWlpTjyyCM7vTa8D/bu3Yvdu3dj5MiRaGpqwpYtWzS1O/yXIAD4/e9/DwBYtmyZ7Hbq6+uxe/dujB49Gt988w3q6+sBdPxCtHfvXkyfPh3p6ekRZQZ/1fv0009RWVmJG264IfTLTfRramtrsXLlSkycODHUpt27d6OmpgZjx47F1q1b8f3338u2JdZ8ATrGLFjehg0b8I9//ANFRUURgSJ6+rS9vT1UXvBPU1OT7LaDmpqacOedd+J3v/sdDjvssIjnGhsbEQgE0K1bt5hlmBUc2xtvvDHi8eAvUP/6178AaBvTIC2/GL/00kvIycnBGWecEdFnQ4cORZcuXTqdthqtZ8+e+O677zS08CeSJGHGjBk4//zzcfLJJ0c899prryEQCOD222/v9Muy3Gl5Wt1yyy147rnn8OKLL+Kkk06K+dp169bhpZdeQllZmaZTkonIftwzicgS27dvR3Jysu5FUn19Pe655x7ceOONKCgo6PR8bW0trr/+ehQUFCAjIwM9evRASUlJ6L169erVC1lZWRGPBRcm4deXPPPMMzjuuOOQnp6Obt26oUePHvjXv/6luM0rrrgCCxYsANBx6tLw4cNxxBFHKNZj8ODBGDhwIP7xj39AkqTQqUlyvvjiC5x77rnIyclBdnY2evTogUsuuQSA9j6Irku/fv3g9/sj2vzhhx9izJgxyMrKQm5uLnr06IE//vGPEdupqKgAgJix7lpe8/XXX0OSJMycORM9evSI+BNc/OzatavT+9TmCwDs2LEjVNbxxx+PiooK/POf/4w4ZUpPn2
7ZskWxjkr++te/orm5OdR/4bp164YjjjgCTz75JN566y3s2rULu3fvlr1OyIzt27fD7/fj8MMPj3i8sLAQubm52L5
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# IQR rule: DiabetesPedigreeFunction values outside\n",
"# [Q1 - 1.5*IQR, Q3 + 1.5*IQR] are treated as outliers.\n",
"Q1, Q3 = df[\"DiabetesPedigreeFunction\"].quantile([0.25, 0.75])\n",
"IQR = Q3 - Q1\n",
"threshold = 1.5 * IQR\n",
"lower_bound, upper_bound = Q1 - threshold, Q3 + threshold\n",
"\n",
"dpf = df[\"DiabetesPedigreeFunction\"]\n",
"outliers = dpf.lt(lower_bound) | dpf.gt(upper_bound)\n",
"\n",
"# Show the flagged rows.\n",
"print(\"Выбросы в датасете:\")\n",
"print(df[outliers])\n",
"\n",
"# Replace the outliers with the column median (taken before replacement).\n",
"# NOTE(review): this edits `df` after X_train/X_test were already split off,\n",
"# so the replacement does not reach the model inputs; re-running the cell\n",
"# also recomputes the quartiles on already-cleaned data.\n",
"median_score = dpf.median()\n",
"df.loc[outliers, \"DiabetesPedigreeFunction\"] = median_score\n",
"\n",
"# Scatter plot of the cleaned column against age.\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"ax.scatter(df[\"DiabetesPedigreeFunction\"], df[\"Age\"])\n",
"ax.set_xlabel(\"Функция родословной диабета\")\n",
"ax.set_ylabel(\"Возраст\")\n",
"ax.set_title(\"Диаграмма рассеивания после чистки\")\n",
"plt.show()"
]
},
2024-11-08 22:14:23 +04:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Классификация данных"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 271,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"# StandardScaler lives in sklearn.preprocessing; importing it from\n",
"# sklearn.discriminant_analysis only worked through an internal re-export.\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"# Excluded from the model: the `Outcome` target (it is still a column of\n",
"# X_train) plus the features left out on purpose.\n",
"columns_to_drop = [\"Pregnancies\", \"SkinThickness\", \"BloodPressure\", \"Outcome\", \"DiabetesPedigreeFunction\"]\n",
"num_columns = [\n",
"    column\n",
"    for column in df.columns\n",
"    if column not in columns_to_drop and df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
"    column\n",
"    for column in df.columns\n",
"    if column not in columns_to_drop and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"# Numeric branch: median imputation followed by standardisation.\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
"    [\n",
"        (\"imputer\", num_imputer),\n",
"        (\"scaler\", num_scaler),\n",
"    ]\n",
")\n",
"\n",
"# Categorical branch: constant imputation followed by one-hot encoding\n",
"# (first level dropped, unseen categories ignored at transform time).\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
"    [\n",
"        (\"imputer\", cat_imputer),\n",
"        (\"encoder\", cat_encoder),\n",
"    ]\n",
")\n",
"\n",
"# Stage 1: transform the selected columns and pass every other column through.\n",
"# The transformer keys keep their original spelling in case later cells look\n",
"# them up by name.\n",
"features_preprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"prepocessing_num\", preprocessing_num, num_columns),\n",
"        (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
"    ],\n",
"    remainder=\"passthrough\"\n",
")\n",
"\n",
"# Stage 2: remove the passed-through columns listed in columns_to_drop.\n",
"drop_columns = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"drop_columns\", \"drop\", columns_to_drop),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
"    [\n",
"        (\"features_preprocessing\", features_preprocessing),\n",
"        (\"drop_columns\", drop_columns),\n",
"    ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Проверка работы конвейера"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 272,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
2024-11-09 10:46:39 +04:00
" <th>Glucose</th>\n",
2024-11-08 22:14:23 +04:00
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
2024-11-09 10:46:39 +04:00
" <th>Age</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.478144</td>\n",
2024-11-08 22:14:23 +04:00
" <td>-0.688684</td>\n",
" <td>-0.946400</td>\n",
2024-11-09 10:46:39 +04:00
" <td>-1.029257</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
2024-11-09 10:46:39 +04:00
" <td>0.818506</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0.180416</td>\n",
" <td>-0.377190</td>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.522334</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
2024-11-09 10:46:39 +04:00
" <td>-1.268784</td>\n",
2024-11-08 22:14:23 +04:00
" <td>-0.688684</td>\n",
" <td>-3.953317</td>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.944770</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
2024-11-09 10:46:39 +04:00
" <td>-1.015779</td>\n",
2024-11-08 22:14:23 +04:00
" <td>-0.688684</td>\n",
" <td>-0.538054</td>\n",
2024-11-09 10:46:39 +04:00
" <td>0.322537</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.003760</td>\n",
2024-11-08 22:14:23 +04:00
" <td>-0.688684</td>\n",
" <td>-0.637047</td>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.522334</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
2024-11-09 10:46:39 +04:00
" <td>0.122742</td>\n",
2024-11-08 22:14:23 +04:00
" <td>-0.688684</td>\n",
" <td>-0.562802</td>\n",
2024-11-09 10:46:39 +04:00
" <td>0.238050</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.794400</td>\n",
2024-11-08 22:14:23 +04:00
" <td>-0.375808</td>\n",
" <td>0.674613</td>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.775796</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.731149</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0.528056</td>\n",
" <td>-1.082516</td>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.944770</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.098637</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0.232562</td>\n",
" <td>0.229143</td>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.522334</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.414893</td>\n",
2024-11-08 22:14:23 +04:00
" <td>-0.271516</td>\n",
" <td>-1.119638</td>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.860283</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
" Glucose Insulin BMI Age\n",
"196 -0.478144 -0.688684 -0.946400 -1.029257\n",
"69 0.818506 0.180416 -0.377190 -0.522334\n",
"494 -1.268784 -0.688684 -3.953317 -0.944770\n",
"463 -1.015779 -0.688684 -0.538054 0.322537\n",
"653 -0.003760 -0.688684 -0.637047 -0.522334\n",
".. ... ... ... ...\n",
"322 0.122742 -0.688684 -0.562802 0.238050\n",
"109 -0.794400 -0.375808 0.674613 -0.775796\n",
"27 -0.731149 0.528056 -1.082516 -0.944770\n",
"651 -0.098637 0.232562 0.229143 -0.522334\n",
"197 -0.414893 -0.271516 -1.119638 -0.860283\n",
2024-11-08 22:14:23 +04:00
"\n",
"[614 rows x 4 columns]"
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 272,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование набора моделей для классификации"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 273,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=9\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=9,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 274,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сводная таблица оценок качества для использованных моделей классификации\n",
"\n",
"Матрица неточностей"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 275,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
2024-11-09 10:46:39 +04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAQ9CAYAAACMbQYZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxUVf8H8M9lGQaBYRMYUEAQdyXXFHcJQ3NNfplmT+7mrqhlPqXivpSPpKKmGWRqpqamlpqZmvujuKRp5IKCsqggICjbzP39wePkBIMzMDALn/frdV8555458x3S++Xcc+45giiKIoiIiIiIiAgAYGHoAIiIiIiIiIwJO0lEREREREQvYCeJiIiIiIjoBewkERERERERvYCdJCIiIiIiohewk0RERERERPQCdpKIiIiIiIhewE4SERERERHRC9hJIiIiIiIiegE7SWRwMTExEAQBd+7cqZD279y5A0EQEBMTo5f2jh49CkEQcPToUb20R0REZE4iIiIgCIJWdQVBQERERMUGRFQG7CQRabB69Wq9dayIiIiIyHRYGToAoorm6+uLZ8+ewdraWqf3rV69GtWrV8eQIUPUyjt27Ihnz55BIpHoMUoiIiLz8Mknn+Cjjz4ydBhE5cJOEpk9QRAglUr11p6FhYVe2yMiIjIXOTk5sLOzg5UVf8Uk08bpdmSUVq9ejUaNGsHGxgZeXl4YN24cMjIyitWLioqCv78/bG1t8eqrr+L48ePo3LkzOnfurKpT0jNJKSkpGDp0KGrWrAkbGxt4enqiT58+queiatWqhT/++APHjh2DIAgQBEHVpqZnks6ePYs33ngDzs7OsLOzQ2BgID7//HP9/mCIiIiMxPNnj65du4Z33nkHzs7OaN++fYnPJOXl5SE8PBxubm5wcHBA7969ce/evRLbPXr0KFq2bAmpVIratWvjiy++0Pic06ZNm9CiRQvY2trCxcUFAwYMQGJiYoV8X6pa2M0noxMREYE5c+YgJCQEY8aMQVxcHNasWYNz587h5MmTqmlza9aswfjx49GhQweEh4fjzp076Nu3L5ydnVGzZs1SPyMsLAx//PEHJkyYgFq1auHBgwc4dOgQEhISUKtWLURGRmLChAmwt7fHxx9/DADw8PDQ2N6hQ4fQs2dPeHp6YtKkSZDL5bh+/Tr27duHSZMm6e+HQ0REZGTeeust1KlTBwsXLoQoinjw4EGxOiNGjMCmTZvwzjvvoG3btvj111/Ro0ePYvUuXryIbt26wdPTE3PmzIFCocDcuXPh5uZWrO6CBQswc+ZM9O/fHyNGjMDDhw+xcuVKdOzYERcvXoSTk1NFfF2qKkQiA4uOjhYBiPHx8eKDBw9EiUQivv7666JCoVDVWbVqlQhA/Oqrr0RRFMW8vDzR1dVVbNWqlVhQUKCqFxMTIwIQO3XqpCqLj48XAYjR0dGiKIri48ePRQDip59+WmpcjRo1UmvnuSNHjogAxCNHjoiiKIqFhYWin5+f6OvrKz5+/FitrlKp1P4HQUREZEJmz54tAhAHDhxYYvlzly5dEgGIY8eOVav3zjvviADE2bNnq8p69eolVqtWTbx//76q7MaNG6KVlZVam3fu3BEtLS3FBQsWqLV55coV0crKqlg5ka443Y6Myi+//IL8/HxMnjwZFhZ///UcOXIkZDIZfvzxRwDA+fPnkZaWhpEjR6rNex40aBCcnZ1L/QxbW1tIJBIcPXoUjx8/LnfMFy9eRHx8PCZPnlzsrpW2S6ASERGZqtGjR5d6/qeffgIATJw4Ua188uTJaq8VCgV++eUX9O3bF15eXqrygIAAdO/eXa3uzp07oVQq0b9/fzx69Eh1yOVy1KlTB0eOHCnHNyLidDsyMnfv3gUA1KtXT61cIpHA399fdf75fwMCAtTqWVlZoVatWqV+ho2NDZYsWYKpU6fCw8MDbdq0Qc+ePfHee+9BLpfrHPOtW7cAAI0bN9b5vURERK
bOz8+v1PN3796FhYUFateurVb+z1z/4MEDPHv2rFhuB4rn+xs3bkAURdSpU6fEz9R1RVuif2IniaqkyZMno1evXti9ezcOHjyImTNnYtGiRfj111/RrFkzQ4dHRERkMmxtbSv9M5VKJQRBwP79+2FpaVnsvL29faXHROaF0+3IqPj6+gIA4uLi1Mrz8/MRHx+vOv/8vzdv3lSrV1hYqFqh7mVq166NqVOn4ueff8bVq1eRn5+PZcuWqc5rO1Xu+Z2xq1evalWfiIioKvH19YVSqVTNvHjun7ne3d0dUqm0WG4Hiuf72rVrQxRF+Pn5ISQkpNjRpk0b/X8RqlLYSSKjEhISAolEghUrVkAURVX5hg0bkJmZqVoJp2XLlnB1dcX69etRWFioqrd58+aXPmf09OlT5ObmqpXVrl0bDg4OyMvLU5XZ2dmVuOz4PzVv3hx+fn6IjIwsVv/F70BERFQVPX+eaMWKFWrlkZGRaq8tLS0REhKC3bt3IykpSVV+8+ZN7N+/X61uv379YGlpiTlz5hTLtaIoIi0tTY/fgKoiTrcjo+Lm5oYZM2Zgzpw56NatG3r37o24uDisXr0arVq1wrvvvgug6BmliIgITJgwAcHBwejfvz/u3LmDmJgY1K5du9RRoL/++guvvfYa+vfvj4YNG8LKygq7du1CamoqBgwYoKrXokULrFmzBvPnz0dAQADc3d0RHBxcrD0LCwusWbMGvXr1QtOmTTF06FB4enrizz//xB9//IGDBw/q/wdFRERkIpo2bYqBAwdi9erVyMzMRNu2bXH48OESR4wiIiLw888/o127dhgzZgwUCgVWrVqFxo0b49KlS6p6tWvXxvz58zFjxgzVFiAODg6Ij4/Hrl27MGrUKEybNq0SvyWZG3aSyOhERETAzc0Nq1atQnh4OFxcXDBq1CgsXLhQ7UHM8ePHQxRFLFu2DNOmTcMrr7yCPXv2YOLEiZBKpRrb9/b2xsCBA3H48GF88803sLKyQv369bFt2zaEhYWp6s2aNQt3797F0qVL8eTJE3Tq1KnEThIAhIaG4siRI5gzZw6WLVsGpVKJ2rVrY+TIkfr7wRAREZmor776Cm5ubti8eTN2796N4OBg/Pjjj/D29lar16JFC+zfvx/Tpk3DzJkz4e3tjblz5+L69ev4888/1ep+9NFHqFu3LpYvX445c+YAKMrxr7/+Onr37l1p343MkyByPhCZEaVSCTc3N/Tr1w/r1683dDhERESkB3379sUff/yBGzduGDoUqiL4TBKZrNzc3GLzkDdu3Ij09HR07tzZMEERERFRuTx79kzt9Y0bN/DTTz8xt1Ol4kgSmayjR48iPDwcb731FlxdXXHhwgVs2LABDRo0QGxsLCQSiaFDJCIiIh15enpiyJAhqv0R16xZg7y8PFy8eFHjvkhE+sZnkshk1apVC97e3lixYgXS09Ph4uKC9957D4sXL2YHiYiIyER169YN3377LVJSUmBjY4OgoCAsXLiQHSSqVBxJIiIiIiIiegGfSSIiIiIiInoBO0lEREREREQv4DNJlUypVCIpKQkODg6lbnhKZI5EUcSTJ0/g5eUFCwv936PJzc1Ffn5+qXUkEkmp+2gRUdXD3ExVXUXmZ1PNzewkVbKkpKRiG6cRVTWJiYmoWbOmXtvMzc2Fn689Uh4oSq0nl8sRHx9vdBdjIjIc5maiIvrOz6acm9lJqmQODg4AgLsXakFmz9mOhvBm3SaGDqHKKkQBTuAn1b8DfcrPz0fKAwVunveGzKHkf1tZT5QIaJmI/Px8o7oQE5FhMTcbXoelww0dQpWmyM/F9Y3z9J6fTTk3s5NUyZ4P48vsLTT+ZaGKZSVYGzqEqut/a2lW5HQWewcB9g4lt68Ep9EQUXHMzYZnKTGeX46rsorKz6aYm9lJIiKzUiAqUKBhZ4MCUVnJ0RAREZEp5mZ2kojIrCghQomSL8SayomIiKjimGJuZieJiMyKEiIUJnYhJi
IiMmemmJvZSSIis1IgKlGg4XprrEP6RERE5swUczM7SURkVpT/OzSdIyIiosplirmZnSQiMiuKUob0NZUTERFRxTH
2024-11-08 22:14:23 +04:00
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Healthy\", \"Sick\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 276,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row0_col0 {\n",
" background-color: #3aba76;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row0_col1, #T_09af8_row1_col0, #T_09af8_row1_col2, #T_09af8_row7_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row0_col2 {\n",
" background-color: #24878e;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row0_col3 {\n",
" background-color: #23888e;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row0_col4, #T_09af8_row4_col7 {\n",
" background-color: #ac2694;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row0_col5, #T_09af8_row1_col4, #T_09af8_row1_col5, #T_09af8_row1_col6, #T_09af8_row1_col7 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row0_col6 {\n",
" background-color: #6f00a8;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row0_col7 {\n",
" background-color: #c6417d;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row1_col1 {\n",
" background-color: #7fd34e;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row1_col3, #T_09af8_row4_col1, #T_09af8_row6_col3 {\n",
" background-color: #44bf70;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row2_col0, #T_09af8_row5_col3 {\n",
" background-color: #37b878;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_09af8_row2_col1 {\n",
" background-color: #a2da37;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row2_col2, #T_09af8_row2_col3, #T_09af8_row7_col0, #T_09af8_row7_col1 {\n",
" background-color: #26818e;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row2_col4 {\n",
" background-color: #aa2395;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row2_col5 {\n",
" background-color: #d8576b;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row2_col6 {\n",
" background-color: #6600a7;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row2_col7 {\n",
" background-color: #bf3984;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row3_col0 {\n",
" background-color: #95d840;\n",
" color: #000000;\n",
2024-11-08 22:14:23 +04:00
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row3_col1, #T_09af8_row3_col2 {\n",
" background-color: #6ece58;\n",
" color: #000000;\n",
2024-11-08 22:14:23 +04:00
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row3_col3 {\n",
" background-color: #21918c;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row3_col4 {\n",
" background-color: #d24f71;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row3_col5 {\n",
" background-color: #d04d73;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row3_col6 {\n",
" background-color: #cc4778;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row3_col7 {\n",
" background-color: #b7318a;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row4_col0 {\n",
" background-color: #3bbb75;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row4_col2 {\n",
" background-color: #22a785;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row4_col3 {\n",
" background-color: #1fa188;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row4_col4 {\n",
" background-color: #b42e8d;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row4_col5 {\n",
" background-color: #c7427c;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row4_col6 {\n",
" background-color: #8e0ca4;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row5_col0 {\n",
" background-color: #23a983;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row5_col1 {\n",
" background-color: #38b977;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row5_col2 {\n",
" background-color: #24aa83;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row5_col4 {\n",
" background-color: #a62098;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row5_col5, #T_09af8_row6_col4 {\n",
" background-color: #c33d80;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row5_col6 {\n",
" background-color: #7b02a8;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row5_col7, #T_09af8_row6_col6 {\n",
" background-color: #b12a90;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row6_col0 {\n",
" background-color: #56c667;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row6_col1 {\n",
" background-color: #34b679;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row6_col2 {\n",
" background-color: #52c569;\n",
" color: #000000;\n",
2024-11-09 09:14:53 +04:00
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row6_col5 {\n",
" background-color: #c03a83;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row6_col7 {\n",
" background-color: #b22b8f;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row7_col2 {\n",
" background-color: #7ad151;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_09af8_row7_col4, #T_09af8_row7_col5, #T_09af8_row7_col6, #T_09af8_row7_col7 {\n",
" background-color: #4e02a2;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-11-09 10:46:39 +04:00
"<table id=\"T_09af8\">\n",
2024-11-08 22:14:23 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_09af8_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_09af8_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_09af8_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_09af8_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_09af8_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_09af8_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_09af8_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_09af8_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_09af8_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_09af8_row0_col0\" class=\"data row0 col0\" >0.710843</td>\n",
" <td id=\"T_09af8_row0_col1\" class=\"data row0 col1\" >0.714286</td>\n",
" <td id=\"T_09af8_row0_col2\" class=\"data row0 col2\" >0.551402</td>\n",
" <td id=\"T_09af8_row0_col3\" class=\"data row0 col3\" >0.648148</td>\n",
" <td id=\"T_09af8_row0_col4\" class=\"data row0 col4\" >0.765472</td>\n",
" <td id=\"T_09af8_row0_col5\" class=\"data row0 col5\" >0.785714</td>\n",
" <td id=\"T_09af8_row0_col6\" class=\"data row0 col6\" >0.621053</td>\n",
" <td id=\"T_09af8_row0_col7\" class=\"data row0 col7\" >0.679612</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_09af8_level0_row1\" class=\"row_heading level0 row1\" >random_forest</th>\n",
" <td id=\"T_09af8_row1_col0\" class=\"data row1 col0\" >0.977169</td>\n",
" <td id=\"T_09af8_row1_col1\" class=\"data row1 col1\" >0.666667</td>\n",
" <td id=\"T_09af8_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_09af8_row1_col3\" class=\"data row1 col3\" >0.777778</td>\n",
" <td id=\"T_09af8_row1_col4\" class=\"data row1 col4\" >0.991857</td>\n",
" <td id=\"T_09af8_row1_col5\" class=\"data row1 col5\" >0.785714</td>\n",
" <td id=\"T_09af8_row1_col6\" class=\"data row1 col6\" >0.988453</td>\n",
" <td id=\"T_09af8_row1_col7\" class=\"data row1 col7\" >0.717949</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_09af8_level0_row2\" class=\"row_heading level0 row2\" >naive_bayes</th>\n",
" <td id=\"T_09af8_row2_col0\" class=\"data row2 col0\" >0.702532</td>\n",
" <td id=\"T_09af8_row2_col1\" class=\"data row2 col1\" >0.708333</td>\n",
" <td id=\"T_09af8_row2_col2\" class=\"data row2 col2\" >0.518692</td>\n",
" <td id=\"T_09af8_row2_col3\" class=\"data row2 col3\" >0.629630</td>\n",
" <td id=\"T_09af8_row2_col4\" class=\"data row2 col4\" >0.755700</td>\n",
" <td id=\"T_09af8_row2_col5\" class=\"data row2 col5\" >0.779221</td>\n",
" <td id=\"T_09af8_row2_col6\" class=\"data row2 col6\" >0.596774</td>\n",
" <td id=\"T_09af8_row2_col7\" class=\"data row2 col7\" >0.666667</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_09af8_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
" <td id=\"T_09af8_row3_col0\" class=\"data row3 col0\" >0.941463</td>\n",
" <td id=\"T_09af8_row3_col1\" class=\"data row3 col1\" >0.642857</td>\n",
" <td id=\"T_09af8_row3_col2\" class=\"data row3 col2\" >0.901869</td>\n",
" <td id=\"T_09af8_row3_col3\" class=\"data row3 col3\" >0.666667</td>\n",
" <td id=\"T_09af8_row3_col4\" class=\"data row3 col4\" >0.946254</td>\n",
" <td id=\"T_09af8_row3_col5\" class=\"data row3 col5\" >0.753247</td>\n",
" <td id=\"T_09af8_row3_col6\" class=\"data row3 col6\" >0.921241</td>\n",
" <td id=\"T_09af8_row3_col7\" class=\"data row3 col7\" >0.654545</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_09af8_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_09af8_row4_col0\" class=\"data row4 col0\" >0.716346</td>\n",
" <td id=\"T_09af8_row4_col1\" class=\"data row4 col1\" >0.584615</td>\n",
" <td id=\"T_09af8_row4_col2\" class=\"data row4 col2\" >0.696262</td>\n",
" <td id=\"T_09af8_row4_col3\" class=\"data row4 col3\" >0.703704</td>\n",
" <td id=\"T_09af8_row4_col4\" class=\"data row4 col4\" >0.798046</td>\n",
" <td id=\"T_09af8_row4_col5\" class=\"data row4 col5\" >0.720779</td>\n",
" <td id=\"T_09af8_row4_col6\" class=\"data row4 col6\" >0.706161</td>\n",
" <td id=\"T_09af8_row4_col7\" class=\"data row4 col7\" >0.638655</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_09af8_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
" <td id=\"T_09af8_row5_col0\" class=\"data row5 col0\" >0.610442</td>\n",
" <td id=\"T_09af8_row5_col1\" class=\"data row5 col1\" >0.561644</td>\n",
" <td id=\"T_09af8_row5_col2\" class=\"data row5 col2\" >0.710280</td>\n",
" <td id=\"T_09af8_row5_col3\" class=\"data row5 col3\" >0.759259</td>\n",
" <td id=\"T_09af8_row5_col4\" class=\"data row5 col4\" >0.741042</td>\n",
" <td id=\"T_09af8_row5_col5\" class=\"data row5 col5\" >0.707792</td>\n",
" <td id=\"T_09af8_row5_col6\" class=\"data row5 col6\" >0.656587</td>\n",
" <td id=\"T_09af8_row5_col7\" class=\"data row5 col7\" >0.645669</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_09af8_level0_row6\" class=\"row_heading level0 row6\" >decision_tree</th>\n",
" <td id=\"T_09af8_row6_col0\" class=\"data row6 col0\" >0.793860</td>\n",
" <td id=\"T_09af8_row6_col1\" class=\"data row6 col1\" >0.552632</td>\n",
" <td id=\"T_09af8_row6_col2\" class=\"data row6 col2\" >0.845794</td>\n",
" <td id=\"T_09af8_row6_col3\" class=\"data row6 col3\" >0.777778</td>\n",
" <td id=\"T_09af8_row6_col4\" class=\"data row6 col4\" >0.869707</td>\n",
" <td id=\"T_09af8_row6_col5\" class=\"data row6 col5\" >0.701299</td>\n",
" <td id=\"T_09af8_row6_col6\" class=\"data row6 col6\" >0.819005</td>\n",
" <td id=\"T_09af8_row6_col7\" class=\"data row6 col7\" >0.646154</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_09af8_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_09af8_row7_col0\" class=\"data row7 col0\" >0.379576</td>\n",
" <td id=\"T_09af8_row7_col1\" class=\"data row7 col1\" >0.376000</td>\n",
" <td id=\"T_09af8_row7_col2\" class=\"data row7 col2\" >0.920561</td>\n",
" <td id=\"T_09af8_row7_col3\" class=\"data row7 col3\" >0.870370</td>\n",
" <td id=\"T_09af8_row7_col4\" class=\"data row7 col4\" >0.447883</td>\n",
" <td id=\"T_09af8_row7_col5\" class=\"data row7 col5\" >0.448052</td>\n",
" <td id=\"T_09af8_row7_col6\" class=\"data row7 col6\" >0.537517</td>\n",
" <td id=\"T_09af8_row7_col7\" class=\"data row7 col7\" >0.525140</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
"<pandas.io.formats.style.Styler at 0x203b54857c0>"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 276,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 277,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row0_col0, #T_4cc76_row0_col1, #T_4cc76_row2_col0 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row0_col2, #T_4cc76_row0_col3, #T_4cc76_row0_col4 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row1_col0 {\n",
" background-color: #8bd646;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row1_col1 {\n",
" background-color: #52c569;\n",
" color: #000000;\n",
"}\n",
"#T_4cc76_row1_col2 {\n",
" background-color: #d24f71;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row1_col3, #T_4cc76_row4_col2 {\n",
" background-color: #c9447a;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row1_col4 {\n",
" background-color: #c6417d;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row2_col1 {\n",
" background-color: #70cf57;\n",
" color: #000000;\n",
"}\n",
"#T_4cc76_row2_col2 {\n",
" background-color: #ce4b75;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row2_col3 {\n",
" background-color: #d5536f;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row2_col4 {\n",
" background-color: #d45270;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row3_col0 {\n",
" background-color: #67cc5c;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row3_col1 {\n",
" background-color: #46c06f;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row3_col2 {\n",
" background-color: #cd4a76;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row3_col3, #T_4cc76_row3_col4, #T_4cc76_row5_col4, #T_4cc76_row6_col4 {\n",
" background-color: #bb3488;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row4_col0 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
2024-11-08 22:14:23 +04:00
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row4_col1 {\n",
" background-color: #60ca60;\n",
" color: #000000;\n",
2024-11-08 22:14:23 +04:00
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row4_col3 {\n",
" background-color: #d14e72;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row4_col4 {\n",
" background-color: #cf4c74;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row5_col0 {\n",
" background-color: #73d056;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row5_col1 {\n",
" background-color: #3fbc73;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row5_col2 {\n",
" background-color: #c23c81;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row5_col3, #T_4cc76_row6_col2 {\n",
" background-color: #bd3786;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row6_col0 {\n",
" background-color: #63cb5f;\n",
" color: #000000;\n",
"}\n",
"#T_4cc76_row6_col1 {\n",
" background-color: #48c16e;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row6_col3 {\n",
" background-color: #ba3388;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row7_col0, #T_4cc76_row7_col1 {\n",
" background-color: #26818e;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_4cc76_row7_col2, #T_4cc76_row7_col3, #T_4cc76_row7_col4 {\n",
" background-color: #4e02a2;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-11-09 10:46:39 +04:00
"<table id=\"T_4cc76\">\n",
2024-11-08 22:14:23 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_4cc76_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_4cc76_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_4cc76_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_4cc76_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_4cc76_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_4cc76_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_4cc76_row0_col0\" class=\"data row0 col0\" >0.785714</td>\n",
" <td id=\"T_4cc76_row0_col1\" class=\"data row0 col1\" >0.717949</td>\n",
" <td id=\"T_4cc76_row0_col2\" class=\"data row0 col2\" >0.867222</td>\n",
" <td id=\"T_4cc76_row0_col3\" class=\"data row0 col3\" >0.546816</td>\n",
" <td id=\"T_4cc76_row0_col4\" class=\"data row0 col4\" >0.551041</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_4cc76_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
" <td id=\"T_4cc76_row1_col0\" class=\"data row1 col0\" >0.753247</td>\n",
" <td id=\"T_4cc76_row1_col1\" class=\"data row1 col1\" >0.654545</td>\n",
" <td id=\"T_4cc76_row1_col2\" class=\"data row1 col2\" >0.845741</td>\n",
" <td id=\"T_4cc76_row1_col3\" class=\"data row1 col3\" >0.462725</td>\n",
" <td id=\"T_4cc76_row1_col4\" class=\"data row1 col4\" >0.462910</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_4cc76_level0_row2\" class=\"row_heading level0 row2\" >logistic</th>\n",
" <td id=\"T_4cc76_row2_col0\" class=\"data row2 col0\" >0.785714</td>\n",
" <td id=\"T_4cc76_row2_col1\" class=\"data row2 col1\" >0.679612</td>\n",
" <td id=\"T_4cc76_row2_col2\" class=\"data row2 col2\" >0.835556</td>\n",
" <td id=\"T_4cc76_row2_col3\" class=\"data row2 col3\" >0.519205</td>\n",
" <td id=\"T_4cc76_row2_col4\" class=\"data row2 col4\" >0.520588</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_4cc76_level0_row3\" class=\"row_heading level0 row3\" >ridge</th>\n",
" <td id=\"T_4cc76_row3_col0\" class=\"data row3 col0\" >0.707792</td>\n",
" <td id=\"T_4cc76_row3_col1\" class=\"data row3 col1\" >0.645669</td>\n",
" <td id=\"T_4cc76_row3_col2\" class=\"data row3 col2\" >0.833889</td>\n",
" <td id=\"T_4cc76_row3_col3\" class=\"data row3 col3\" >0.406373</td>\n",
" <td id=\"T_4cc76_row3_col4\" class=\"data row3 col4\" >0.419772</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_4cc76_level0_row4\" class=\"row_heading level0 row4\" >naive_bayes</th>\n",
" <td id=\"T_4cc76_row4_col0\" class=\"data row4 col0\" >0.779221</td>\n",
" <td id=\"T_4cc76_row4_col1\" class=\"data row4 col1\" >0.666667</td>\n",
" <td id=\"T_4cc76_row4_col2\" class=\"data row4 col2\" >0.822593</td>\n",
" <td id=\"T_4cc76_row4_col3\" class=\"data row4 col3\" >0.502471</td>\n",
" <td id=\"T_4cc76_row4_col4\" class=\"data row4 col4\" >0.504419</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_4cc76_level0_row5\" class=\"row_heading level0 row5\" >knn</th>\n",
" <td id=\"T_4cc76_row5_col0\" class=\"data row5 col0\" >0.720779</td>\n",
" <td id=\"T_4cc76_row5_col1\" class=\"data row5 col1\" >0.638655</td>\n",
" <td id=\"T_4cc76_row5_col2\" class=\"data row5 col2\" >0.806296</td>\n",
" <td id=\"T_4cc76_row5_col3\" class=\"data row5 col3\" >0.414293</td>\n",
" <td id=\"T_4cc76_row5_col4\" class=\"data row5 col4\" >0.419023</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_4cc76_level0_row6\" class=\"row_heading level0 row6\" >decision_tree</th>\n",
" <td id=\"T_4cc76_row6_col0\" class=\"data row6 col0\" >0.701299</td>\n",
" <td id=\"T_4cc76_row6_col1\" class=\"data row6 col1\" >0.646154</td>\n",
" <td id=\"T_4cc76_row6_col2\" class=\"data row6 col2\" >0.794167</td>\n",
" <td id=\"T_4cc76_row6_col3\" class=\"data row6 col3\" >0.400271</td>\n",
" <td id=\"T_4cc76_row6_col4\" class=\"data row6 col4\" >0.417827</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_4cc76_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_4cc76_row7_col0\" class=\"data row7 col0\" >0.448052</td>\n",
" <td id=\"T_4cc76_row7_col1\" class=\"data row7 col1\" >0.525140</td>\n",
" <td id=\"T_4cc76_row7_col2\" class=\"data row7 col2\" >0.603333</td>\n",
" <td id=\"T_4cc76_row7_col3\" class=\"data row7 col3\" >0.069387</td>\n",
" <td id=\"T_4cc76_row7_col4\" class=\"data row7 col4\" >0.110298</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
"<pandas.io.formats.style.Styler at 0x203bb436870>"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 277,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test-set quality of every candidate model, best ROC AUC first.\n",
"class_metrics = pd.DataFrame.from_dict(class_models, orient=\"index\").loc[\n",
"    :,\n",
"    [\"Accuracy_test\", \"F1_test\", \"ROC_AUC_test\", \"Cohen_kappa_test\", \"MCC_test\"],\n",
"]\n",
"class_metrics.sort_values(\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
"    cmap=\"plasma\",\n",
"    low=0.3,\n",
"    high=1,\n",
"    subset=[\"ROC_AUC_test\", \"MCC_test\", \"Cohen_kappa_test\"],\n",
").background_gradient(\n",
"    cmap=\"viridis\",\n",
"    low=1,\n",
"    high=0.3,\n",
"    subset=[\"Accuracy_test\", \"F1_test\"],\n",
")"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 278,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-11-09 10:46:39 +04:00
"'random_forest'"
2024-11-08 22:14:23 +04:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Pick the model with the highest Matthews correlation coefficient on the test set.\n",
"best_model = str(\n",
"    class_metrics.sort_values(by=\"MCC_test\", ascending=False).index[0]\n",
")\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод тестовых записей, на которых лучшая модель ошиблась, для анализа ошибок"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 279,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-11-09 10:46:39 +04:00
"'Error items count: 33'"
2024-11-08 22:14:23 +04:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Predicted</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>46</th>\n",
2024-11-08 22:14:23 +04:00
" <td>1</td>\n",
2024-11-09 10:46:39 +04:00
" <td>1</td>\n",
" <td>146</td>\n",
" <td>56</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>29.7</td>\n",
" <td>0.564</td>\n",
" <td>29</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>86</th>\n",
" <td>13</td>\n",
" <td>1</td>\n",
" <td>106</td>\n",
" <td>72</td>\n",
" <td>54</td>\n",
" <td>0</td>\n",
" <td>36.6</td>\n",
" <td>0.178</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>91</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>123</td>\n",
" <td>80</td>\n",
" <td>15</td>\n",
" <td>176</td>\n",
" <td>32.0</td>\n",
" <td>0.443</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>144</td>\n",
" <td>72</td>\n",
" <td>27</td>\n",
" <td>228</td>\n",
" <td>33.9</td>\n",
" <td>0.255</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>125</th>\n",
2024-11-08 22:14:23 +04:00
" <td>1</td>\n",
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>88</td>\n",
" <td>30</td>\n",
2024-11-08 22:14:23 +04:00
" <td>42</td>\n",
2024-11-09 10:46:39 +04:00
" <td>99</td>\n",
" <td>55.0</td>\n",
" <td>0.496</td>\n",
" <td>26</td>\n",
" <td>1</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>167</th>\n",
" <td>4</td>\n",
2024-11-08 22:14:23 +04:00
" <td>1</td>\n",
2024-11-09 10:46:39 +04:00
" <td>120</td>\n",
" <td>68</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>29.6</td>\n",
" <td>0.709</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>188</th>\n",
" <td>8</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>109</td>\n",
" <td>76</td>\n",
" <td>39</td>\n",
" <td>114</td>\n",
" <td>27.9</td>\n",
" <td>0.640</td>\n",
" <td>31</td>\n",
" <td>1</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>204</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>103</td>\n",
" <td>72</td>\n",
" <td>32</td>\n",
" <td>190</td>\n",
" <td>37.7</td>\n",
" <td>0.324</td>\n",
" <td>55</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>228</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>197</td>\n",
" <td>70</td>\n",
" <td>39</td>\n",
" <td>744</td>\n",
" <td>36.7</td>\n",
" <td>2.329</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>274</th>\n",
" <td>13</td>\n",
2024-11-08 22:14:23 +04:00
" <td>1</td>\n",
2024-11-09 10:46:39 +04:00
" <td>106</td>\n",
" <td>70</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>34.2</td>\n",
" <td>0.251</td>\n",
" <td>52</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>280</th>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
" <td>0</td>\n",
" <td>146</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>37.9</td>\n",
" <td>0.334</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>282</th>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>133</td>\n",
" <td>88</td>\n",
" <td>15</td>\n",
" <td>155</td>\n",
" <td>32.4</td>\n",
" <td>0.262</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>309</th>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>124</td>\n",
" <td>68</td>\n",
" <td>28</td>\n",
" <td>205</td>\n",
" <td>32.9</td>\n",
" <td>0.875</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>335</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>165</td>\n",
" <td>76</td>\n",
" <td>43</td>\n",
" <td>255</td>\n",
" <td>47.9</td>\n",
" <td>0.259</td>\n",
" <td>26</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>363</th>\n",
2024-11-08 22:14:23 +04:00
" <td>4</td>\n",
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>146</td>\n",
" <td>78</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>38.5</td>\n",
" <td>0.520</td>\n",
" <td>67</td>\n",
" <td>1</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>397</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>131</td>\n",
" <td>66</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" <td>34.3</td>\n",
" <td>0.196</td>\n",
" <td>22</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>510</th>\n",
" <td>12</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>84</td>\n",
2024-11-08 22:14:23 +04:00
" <td>72</td>\n",
2024-11-09 10:46:39 +04:00
" <td>31</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>29.7</td>\n",
" <td>0.297</td>\n",
2024-11-08 22:14:23 +04:00
" <td>46</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>517</th>\n",
" <td>7</td>\n",
2024-11-08 22:14:23 +04:00
" <td>1</td>\n",
2024-11-09 10:46:39 +04:00
" <td>125</td>\n",
" <td>86</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>37.6</td>\n",
" <td>0.304</td>\n",
" <td>51</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>536</th>\n",
" <td>0</td>\n",
2024-11-08 22:14:23 +04:00
" <td>1</td>\n",
2024-11-09 10:46:39 +04:00
" <td>105</td>\n",
" <td>90</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29.6</td>\n",
" <td>0.197</td>\n",
" <td>46</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>541</th>\n",
2024-11-08 22:14:23 +04:00
" <td>3</td>\n",
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>128</td>\n",
" <td>72</td>\n",
" <td>25</td>\n",
" <td>190</td>\n",
" <td>32.4</td>\n",
" <td>0.549</td>\n",
" <td>27</td>\n",
2024-11-08 22:14:23 +04:00
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>549</th>\n",
" <td>4</td>\n",
2024-11-08 22:14:23 +04:00
" <td>1</td>\n",
2024-11-09 10:46:39 +04:00
" <td>189</td>\n",
" <td>110</td>\n",
" <td>31</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>28.5</td>\n",
" <td>0.680</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>568</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>154</td>\n",
" <td>72</td>\n",
" <td>29</td>\n",
" <td>126</td>\n",
" <td>31.3</td>\n",
" <td>0.338</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>577</th>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>118</td>\n",
" <td>80</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0</td>\n",
" <td>0</td>\n",
2024-11-09 10:46:39 +04:00
" <td>42.9</td>\n",
" <td>0.693</td>\n",
" <td>21</td>\n",
" <td>1</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th>583</th>\n",
" <td>8</td>\n",
" <td>1</td>\n",
" <td>100</td>\n",
" <td>76</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>38.7</td>\n",
" <td>0.190</td>\n",
" <td>42</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>590</th>\n",
" <td>11</td>\n",
" <td>0</td>\n",
" <td>111</td>\n",
" <td>84</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" <td>46.8</td>\n",
" <td>0.925</td>\n",
" <td>45</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-08 22:14:23 +04:00
" <th>594</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>123</td>\n",
" <td>72</td>\n",
" <td>45</td>\n",
" <td>230</td>\n",
" <td>33.6</td>\n",
" <td>0.733</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>622</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>183</td>\n",
" <td>94</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40.8</td>\n",
" <td>1.461</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>630</th>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>114</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.732</td>\n",
" <td>34</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>11</td>\n",
" <td>1</td>\n",
" <td>127</td>\n",
" <td>106</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>39.0</td>\n",
" <td>0.190</td>\n",
" <td>51</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>154</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>30.9</td>\n",
" <td>0.164</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>725</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>112</td>\n",
" <td>78</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" <td>39.4</td>\n",
" <td>0.236</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>744</th>\n",
" <td>13</td>\n",
" <td>1</td>\n",
" <td>153</td>\n",
" <td>88</td>\n",
" <td>37</td>\n",
" <td>140</td>\n",
" <td>40.6</td>\n",
" <td>1.174</td>\n",
" <td>39</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>750</th>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>136</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>31.2</td>\n",
" <td>1.182</td>\n",
" <td>22</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Predicted Glucose BloodPressure SkinThickness Insulin \\\n",
2024-11-09 10:46:39 +04:00
"46 1 1 146 56 0 0 \n",
2024-11-08 22:14:23 +04:00
"86 13 1 106 72 54 0 \n",
"91 4 1 123 80 15 176 \n",
"95 6 1 144 72 27 228 \n",
2024-11-09 10:46:39 +04:00
"125 1 0 88 30 42 99 \n",
"167 4 1 120 68 0 0 \n",
"188 8 0 109 76 39 114 \n",
2024-11-08 22:14:23 +04:00
"204 6 1 103 72 32 190 \n",
"228 4 1 197 70 39 744 \n",
"274 13 1 106 70 0 0 \n",
"280 0 0 146 70 0 0 \n",
"282 7 1 133 88 15 155 \n",
"309 2 0 124 68 28 205 \n",
"335 0 1 165 76 43 255 \n",
2024-11-09 10:46:39 +04:00
"363 4 0 146 78 0 0 \n",
2024-11-08 22:14:23 +04:00
"397 0 0 131 66 40 0 \n",
2024-11-09 10:46:39 +04:00
"510 12 0 84 72 31 0 \n",
2024-11-08 22:14:23 +04:00
"517 7 1 125 86 0 0 \n",
2024-11-09 10:46:39 +04:00
"536 0 1 105 90 0 0 \n",
"541 3 0 128 72 25 190 \n",
"549 4 1 189 110 31 0 \n",
"568 4 1 154 72 29 126 \n",
"577 2 0 118 80 0 0 \n",
2024-11-08 22:14:23 +04:00
"583 8 1 100 76 0 0 \n",
2024-11-09 10:46:39 +04:00
"590 11 0 111 84 40 0 \n",
2024-11-08 22:14:23 +04:00
"594 6 1 123 72 45 230 \n",
"622 6 1 183 94 0 0 \n",
"630 7 0 114 64 0 0 \n",
"658 11 1 127 106 0 0 \n",
"669 9 1 154 78 30 100 \n",
"725 4 1 112 78 40 0 \n",
"744 13 1 153 88 37 140 \n",
"750 4 0 136 70 0 0 \n",
"\n",
" BMI DiabetesPedigreeFunction Age Outcome \n",
2024-11-09 10:46:39 +04:00
"46 29.7 0.564 29 0 \n",
2024-11-08 22:14:23 +04:00
"86 36.6 0.178 45 0 \n",
"91 32.0 0.443 34 0 \n",
"95 33.9 0.255 40 0 \n",
2024-11-09 10:46:39 +04:00
"125 55.0 0.496 26 1 \n",
"167 29.6 0.709 34 0 \n",
"188 27.9 0.640 31 1 \n",
2024-11-08 22:14:23 +04:00
"204 37.7 0.324 55 0 \n",
"228 36.7 2.329 31 0 \n",
"274 34.2 0.251 52 0 \n",
"280 37.9 0.334 28 1 \n",
"282 32.4 0.262 37 0 \n",
"309 32.9 0.875 30 1 \n",
"335 47.9 0.259 26 0 \n",
2024-11-09 10:46:39 +04:00
"363 38.5 0.520 67 1 \n",
2024-11-08 22:14:23 +04:00
"397 34.3 0.196 22 1 \n",
2024-11-09 10:46:39 +04:00
"510 29.7 0.297 46 1 \n",
2024-11-08 22:14:23 +04:00
"517 37.6 0.304 51 0 \n",
2024-11-09 10:46:39 +04:00
"536 29.6 0.197 46 0 \n",
"541 32.4 0.549 27 1 \n",
"549 28.5 0.680 37 0 \n",
"568 31.3 0.338 37 0 \n",
"577 42.9 0.693 21 1 \n",
2024-11-08 22:14:23 +04:00
"583 38.7 0.190 42 0 \n",
2024-11-09 10:46:39 +04:00
"590 46.8 0.925 45 1 \n",
2024-11-08 22:14:23 +04:00
"594 33.6 0.733 34 0 \n",
"622 40.8 1.461 45 0 \n",
"630 27.4 0.732 34 1 \n",
"658 39.0 0.190 51 0 \n",
"669 30.9 0.164 45 0 \n",
"725 39.4 0.236 38 0 \n",
"744 40.6 1.174 39 0 \n",
"750 31.2 1.182 22 1 "
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 279,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Apply the fitted preprocessing step to the test set. Passing X_test's index\n",
"# keeps rows addressable by their original labels (e.g. preprocessed_df.loc[450])\n",
"# even when the transformer returns a plain ndarray instead of a DataFrame.\n",
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
"    preprocessing_result,\n",
"    columns=pipeline_end.get_feature_names_out(),\n",
"    index=X_test.index,\n",
")\n",
"\n",
"# Test-set predictions of the best model chosen above.\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"# Labels of the test rows where the prediction differs from the true Outcome.\n",
"error_index = y_test[y_test[\"Outcome\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"# Misclassified rows, with the predicted class inserted as the second column.\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 280,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>1.0</td>\n",
" <td>82.0</td>\n",
" <td>64.0</td>\n",
" <td>13.0</td>\n",
" <td>95.0</td>\n",
" <td>21.2</td>\n",
" <td>0.415</td>\n",
" <td>23.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"450 1.0 82.0 64.0 13.0 95.0 21.2 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"450 0.415 23.0 0.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
2024-11-09 10:46:39 +04:00
" <th>Glucose</th>\n",
2024-11-08 22:14:23 +04:00
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
2024-11-09 10:46:39 +04:00
" <th>Age</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
2024-11-09 10:46:39 +04:00
" <td>-1.205533</td>\n",
2024-11-08 22:14:23 +04:00
" <td>0.136961</td>\n",
" <td>-1.329999</td>\n",
2024-11-09 10:46:39 +04:00
" <td>-0.860283</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
" Glucose Insulin BMI Age\n",
"450 -1.205533 0.136961 -1.329999 -0.860283"
2024-11-08 22:14:23 +04:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
2024-11-09 10:46:39 +04:00
"'predicted: 0 (proba: [0.96 0.04])'"
2024-11-08 22:14:23 +04:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 0'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 450\n",
"# Double brackets select a one-row DataFrame and keep each column's dtype\n",
"# (transposing a row Series would upcast every column to float).\n",
"test = X_test.loc[[example_id]]\n",
"test_preprocessed = preprocessed_df.loc[[example_id]]\n",
"display(test)\n",
"display(test_preprocessed)\n",
"# Class probabilities and predicted class for the single example.\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 281,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"import pandas as pd\n",
"\n",
"\n",
"# Определяем числовые признаки\n",
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
"\n",
"# Установка random_state\n",
2024-11-09 09:14:53 +04:00
"random_state = 9\n",
2024-11-08 22:14:23 +04:00
"\n",
"# Определение трансформера\n",
"pipeline_end = ColumnTransformer([\n",
" ('numeric', StandardScaler(), numeric_features),\n",
" # Добавьте другие трансформеры, если требуется\n",
"])\n",
"\n",
"# Объявление модели\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"# Создание пайплайна с корректными шагами\n",
"result = {}\n",
"\n",
"# Обучение модели\n",
"result[\"pipeline\"] = Pipeline([\n",
" (\"pipeline\", pipeline_end),\n",
" (\"model\", optimized_model)\n",
"]).fit(X_train, y_train.values.ravel())\n",
"\n",
"# Прогнозирование и расчет метрик\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"# Метрики для оценки модели\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 282,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [],
"source": [
2024-11-09 09:14:53 +04:00
"optimized_model_type = \"random_forest\"\n",
2024-11-08 22:14:23 +04:00
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сравнение метрик старой и новой модели"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 283,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-11-09 10:46:39 +04:00
"#T_cba1c_row0_col0, #T_cba1c_row0_col1, #T_cba1c_row0_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_cba1c_row0_col2, #T_cba1c_row1_col2 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cba1c_row0_col4, #T_cba1c_row0_col5, #T_cba1c_row0_col6, #T_cba1c_row0_col7 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_cba1c_row1_col0, #T_cba1c_row1_col1, #T_cba1c_row1_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_cba1c_row1_col4, #T_cba1c_row1_col5, #T_cba1c_row1_col6, #T_cba1c_row1_col7 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-11-09 10:46:39 +04:00
"<table id=\"T_cba1c\">\n",
2024-11-08 22:14:23 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_cba1c_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_cba1c_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_cba1c_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_cba1c_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_cba1c_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_cba1c_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_cba1c_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_cba1c_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_cba1c_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_cba1c_row0_col0\" class=\"data row0 col0\" >0.977169</td>\n",
" <td id=\"T_cba1c_row0_col1\" class=\"data row0 col1\" >0.666667</td>\n",
" <td id=\"T_cba1c_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_cba1c_row0_col3\" class=\"data row0 col3\" >0.777778</td>\n",
" <td id=\"T_cba1c_row0_col4\" class=\"data row0 col4\" >0.991857</td>\n",
" <td id=\"T_cba1c_row0_col5\" class=\"data row0 col5\" >0.785714</td>\n",
" <td id=\"T_cba1c_row0_col6\" class=\"data row0 col6\" >0.988453</td>\n",
" <td id=\"T_cba1c_row0_col7\" class=\"data row0 col7\" >0.717949</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_cba1c_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_cba1c_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_cba1c_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_cba1c_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_cba1c_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_cba1c_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_cba1c_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_cba1c_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_cba1c_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
"<pandas.io.formats.style.Styler at 0x203bd82cda0>"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 283,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train vs. test classification metrics of the old and the retrained model.\n",
"optimized_metrics.loc[\n",
"    :,\n",
"    [\n",
"        \"Precision_train\", \"Precision_test\",\n",
"        \"Recall_train\", \"Recall_test\",\n",
"        \"Accuracy_train\", \"Accuracy_test\",\n",
"        \"F1_train\", \"F1_test\",\n",
"    ],\n",
"].style.background_gradient(\n",
"    cmap=\"plasma\",\n",
"    low=0.3,\n",
"    high=1,\n",
"    subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
"    cmap=\"viridis\",\n",
"    low=1,\n",
"    high=0.3,\n",
"    subset=[\"Precision_train\", \"Precision_test\", \"Recall_train\", \"Recall_test\"],\n",
")"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 284,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-11-09 10:46:39 +04:00
"#T_c86af_row0_col0, #T_c86af_row0_col1 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_c86af_row0_col2, #T_c86af_row0_col3, #T_c86af_row0_col4 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_c86af_row1_col0, #T_c86af_row1_col1 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_c86af_row1_col2, #T_c86af_row1_col3, #T_c86af_row1_col4 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-11-09 10:46:39 +04:00
"<table id=\"T_c86af\">\n",
2024-11-08 22:14:23 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_c86af_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_c86af_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_c86af_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_c86af_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_c86af_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_c86af_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_c86af_row0_col0\" class=\"data row0 col0\" >0.785714</td>\n",
" <td id=\"T_c86af_row0_col1\" class=\"data row0 col1\" >0.717949</td>\n",
" <td id=\"T_c86af_row0_col2\" class=\"data row0 col2\" >0.867222</td>\n",
" <td id=\"T_c86af_row0_col3\" class=\"data row0 col3\" >0.546816</td>\n",
" <td id=\"T_c86af_row0_col4\" class=\"data row0 col4\" >0.551041</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_c86af_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_c86af_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_c86af_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_c86af_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_c86af_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_c86af_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
"<pandas.io.formats.style.Styler at 0x203b5ef2d20>"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 284,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(\n",
"    optimized_metrics[\n",
"        [\"Accuracy_test\", \"F1_test\", \"ROC_AUC_test\", \"Cohen_kappa_test\", \"MCC_test\"]\n",
"    ]\n",
"    .style\n",
"    # ROC AUC, MCC and Cohen's kappa are shaded with the plasma colormap.\n",
"    .background_gradient(\n",
"        cmap=\"plasma\", low=0.3, high=1,\n",
"        subset=[\"ROC_AUC_test\", \"MCC_test\", \"Cohen_kappa_test\"],\n",
"    )\n",
"    # Accuracy and F1 are shaded separately with viridis (low/high deliberately reversed).\n",
"    .background_gradient(\n",
"        cmap=\"viridis\", low=1, high=0.3,\n",
"        subset=[\"Accuracy_test\", \"F1_test\"],\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 285,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
2024-11-09 10:46:39 +04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA3MAAAGxCAYAAADI9u/sAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABV8klEQVR4nO3de1xUdf7H8fcgchGYQUxBFBTTvJTmrRS7mC6FdtNkK812NbM2b612sdzykl0wf2u6lmmlQW6Z5VZmVpZRWpaaWlqWmamFpqBpgGBcZOb3h+vULKggB4aZ7+v5eJzHOuecOfMZY3n7Pd/LsblcLpcAAAAAAD4lwNsFAAAAAAAqj8YcAAAAAPggGnMAAAAA4INozAEAAACAD6IxBwAAAAA+iMYcAAAAAPggGnMAAAAA4INozAEAAACAD6IxBwAAAAA+iMYcAAAAAPggGnMAAAAAUAkff/yxrrnmGsXGxspms2np0qUex10ulyZNmqTGjRsrNDRUSUlJ2rFjh8c5hw8f1uDBg2W32xUZGalbb71V+fn5laqDxhwAAAAAVEJBQYHOP/98zZkzp9zj06dP1+zZszVv3jytX79eYWFhSk5OVmFhofucwYMH65tvvtHKlSu1fPlyffzxx7r99tsrVYfN5XK5qvRNAAAAAMBQNptNb7zxhvr37y/peK9cbGys7r77bt1zzz2SpNzcXEVHRys9PV0DBw7Utm3b1K5dO23YsEFdu3aVJK1YsUJXXnml9u7dq9jY2Ap9dmC1fCMAQK1TWFio4uJiy64XFBSkkJAQy64HAEBlWJ1rLpdLNpvNY19wcLCCg4MrdZ3du3crKytLSUlJ7n0Oh0PdunXT2rVrNXDgQK1du1aRkZHuhpwkJSUlKSAgQOvXr9d1111Xoc+iMQcABigsLFRCs3BlHSi17JoxMTHavXs3DToAQI2rjlwLDw8vM2dt8uTJmjJlSqWuk5WVJUmKjo722B8dHe0+lpWVpUaNGnkcDwwMVFRUlPuciqAxBwAGKC4uVtaBUu3e1Ez2iKpPl8474lRCl59UXFxMYw4AUOOqK9f27Nkju93u3l/ZXrmaRmMOAAxijwiwJPQAAKgNrM41u93u0Zg7EzExMZKk7OxsNW7c2L0/OztbHTt2dJ9z4MABj/cdO3ZMhw8fdr+/Ikh0ADBIqctp2QYAgLfVxlxLSEhQTEyMMjIy3Pvy8vK0fv16JSYmSpISExOVk5OjTZs2uc/58MMP5XQ61a1btwp/Fj1zAGAQp1xyquqLGFtxDQAAqspbuZafn68ffvjB/Xr37t3avHmzoqKiFB8fr7Fjx+qRRx5Rq1atlJCQoIkTJyo2Nta94mXbtm3Vp08f3XbbbZo3b55KSko0evRoDRw4sMIrWUo05gAAAACgUjZu3KhevXq5X991112SpCFDhig9PV3jx49XQUGBbr/9duXk5Ojiiy/WihUrPOaZv/TSSxo9erT+9Kc/KSAgQCkpKZo9e3al6uA5cwBggLy8PDkcDu3b3tSyieKxrfcqNze3ynMLAACoLHLtOHrmAMAgpS6XSi24h2fFNQAAqCrTc40FUAAAAADAB9EzBwAGYQEUAIA/MT3XaMwBgEGccqnU4NADAPgX03ONYZYAAAAA4IPomQMAg5g+HAUA4F9MzzV65gAAAADAB9EzBwAGMX0JZwCAfzE912jMAYBBnP/drLgOAADeZnquMcwSAAAAAHwQPXMAYJBSi5ZwtuIaAABUlem5RmMOAAxS6jq+WXEdAAC8zfRcY5glAAAAAPggeuYAwCCmTxQHAPgX03ONxhwAGMQpm0pls+Q6AAB4m+m5xjBLAAAAAPBB9MwBgEGcruObFdcBAMDbTM81euYAAAAAwAfRMwcABim1aG6BFdcAAKCqTM81GnMAYBDTQw8A4F9MzzWGWQIAAACAD6JnDgAM4nTZ5HRZsISzBdcAAKCqTM81GnMAYBDTh6MAAPyL6bnGMEsAAA
AA8EH0zAGAQUoVoFIL7uOVWlALAABVZXqu0ZgDAIO4LJpb4PLRuQUAAP9ieq4xzBIAAAAAfBA9cwBgENMnigMA/IvpuUZjDgAMUuoKUKnLgrkFLguKAQCgikzPNYZZAgAAAIAPomcOAAzilE1OC+7jOeWjtzABAH7F9FyjZw4AAAAAfBA9cwBgENMnigMA/IvpuUZjDgAMYt1Ecd8cjgIA8C+m5xrDLAEAAADAB9EzBwAGOT5RvOpDSay4BgAAVWV6rtGYAwCDOBWgUoNX/QIA+BfTc41hlgAAAADgg+iZAwCDmD5RHADgX0zPNRpzAGAQpwKMfrgqAMC/mJ5rDLMEAAAAAB9EzxwAGKTUZVOpy4KHq1pwDQAAqsr0XKNnDgAAAAB8EI05ADBI6X+XcLZiq6jmzZvLZrOV2UaNGiVJKiws1KhRo9SgQQOFh4crJSVF2dnZ1fVXAADwI97ItdrEN6sGAJwRpyvAsq2iNmzYoP3797u3lStXSpKuv/56SdK4ceP01ltvacmSJVq9erX27dunAQMGVMv3BwD4F2/kWm3CnDkAwBnLy8vzeB0cHKzg4GCPfQ0bNvR4PW3aNJ199tnq2bOncnNztWDBAi1atEi9e/eWJKWlpalt27Zat26dunfvXr1fAAAAH+abTVAAwBmxejhKXFycHA6He0tNTT3l5xcXF+vFF1/UsGHDZLPZtGnTJpWUlCgpKcl9Tps2bRQfH6+1a9dW698FAMD3mT7Mkp45ADCIU9as2OX87//u2bNHdrvdvf9/e+X+19KlS5WTk6OhQ4dKkrKyshQUFKTIyEiP86Kjo5WVlVXlOgEA/s3qXPM1NOYAAGfMbrd7NOZOZ8GCBerbt69iY2OrsSoAAMxAYw4ADOJUgJwWDCU5k2v89NNP+uCDD/T666+798XExKi4uFg5OTkevXPZ2dmKiYmpcp0AAP/mzVyrDXyzagDAGSl1BVi2VVZaWpoaNWqkq666yr2vS5cuqlu3rjIyMtz7tm/frszMTCUmJlrynQEA/subuVYb0DMHAKh2TqdTaWlpGjJkiAIDf48eh8OhW2+9VXfddZeioqJkt9s1ZswYJSYmspIlAACnQWMOAAzilE1OWTFRvHLX+OCDD5SZmalhw4aVOTZz5kwFBAQoJSVFRUVFSk5O1tNPP13lGgEA/s9buVZb0JgDAINYNZSkste44oor5HK5yj0WEhKiOXPmaM6cOVWuCwBgFm/lWm3hm1UDAAAAgOHomQMAg1j1YFRffbgqAMC/mJ5rvlk1AAAAABiOnrka5nQ6tW/fPkVERMhm882JlgBqlsvl0pEjRxQbG6uAgKrdg3O6bHK6LJgobsE14B/INQCVRa5Zh8ZcDdu3b5/i4uK8XQYAH7Rnzx41bdq0StdwWjQcxVcfrgrrkWsAzhS5VnU05mpYRESEJOmnL5rLHu6bPzSoPikDUrxdAmqhY6VF+njbbPfvD6A2IddwKted097bJaAWOqYSrdE75JoFaMzVsBNDUOzhAbJHEHrwFFgn2NsloBazYgib0xUgpwXLL1txDfgHcg2nEmir6+0SUBv990k15FrV0ZgDAIOUyqZSCx6MasU1AACoKtNzzTeboAAAAABgOHrmAMAgpg9HAQD4F9NzjcYcABikVNYMJSmteikAAFSZ6bnmm01QAAAAADAcPXMAYBDTh6MAAPyL6bnmm1UDAAAAgOHomQMAg5S6AlRqwd1HK64BAEBVmZ5rvlk1AOCMuGST04LN5aPP4wEA+Bdv5FppaakmTpyohIQEhYaG6uyzz9bDDz8sl8v1e10ulyZNmqTGjRsrNDRUSUlJ2rFjh+Xfn8YcAAAAAFTQ448/rrlz5+qpp57Stm3b9Pjjj2v69Ol68skn3edMnz5ds2fP1rx587R+/XqFhYUpOTlZhYWFltbCMEsAMIjpw1EAAP7FG7n22WefqV+/frrqqqskSc2bN9fLL7+szz//XNLxXrlZs2bpwQcfVL9+/S
RJCxcuVHR0tJYuXaqBAwdWud4TSGMAMIjTZbNsAwDA26zOtby8PI+tqKiozGf26NFDGRkZ+v777yVJW7Zs0Zo1a9S
2024-11-08 22:14:23 +04:00
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# One confusion-matrix panel per row of optimized_metrics, titled with that row's label\n",
"# (the subplot count follows the table instead of being fixed at 2).\n",
"n_models = len(optimized_metrics)\n",
"_, ax = plt.subplots(\n",
"    1, n_models, figsize=(5 * n_models, 4), sharex=False, sharey=False, squeeze=False\n",
")\n",
"\n",
"for index, (model_label, row) in enumerate(optimized_metrics.iterrows()):\n",
"    panel = ax.flat[index]\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=row[\"Confusion_matrix\"], display_labels=[\"Healthy\", \"Sick\"]\n",
"    ).plot(ax=panel)\n",
"    panel.set_title(str(model_label))\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-11-09 10:46:39 +04:00
"В желтом квадрате мы видим значение 79, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"Sick\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложноотрицательных результатов (больных, ошибочно отнесенных к классу \"Healthy\").\n",
2024-11-08 22:14:23 +04:00
"\n",
2024-11-09 10:46:39 +04:00
"В зеленом квадрате значение 42 указывает на количество правильно классифицированных объектов, отнесенных к классу \"Healthy\". Это также является показателем хорошей точности модели в определении объектов данного класса."
2024-11-08 22:14:23 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение достижимого уровня качества модели для второй задачи"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подготовка данных"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 286,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin \\\n",
"count 768.000000 768.000000 768.000000 768.000000 768.000000 \n",
"mean 3.845052 120.894531 69.105469 20.536458 79.799479 \n",
"std 3.369578 31.972618 19.355807 15.952218 115.244002 \n",
"min 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
"25% 1.000000 99.000000 62.000000 0.000000 0.000000 \n",
"50% 3.000000 117.000000 72.000000 23.000000 30.500000 \n",
"75% 6.000000 140.250000 80.000000 32.000000 127.250000 \n",
"max 17.000000 199.000000 122.000000 99.000000 846.000000 \n",
"\n",
" BMI DiabetesPedigreeFunction Age Outcome \n",
"count 768.000000 768.000000 768.000000 768.000000 \n",
"mean 31.992578 0.471876 33.240885 0.348958 \n",
"std 7.884160 0.331329 11.760232 0.476951 \n",
"min 0.000000 0.078000 21.000000 0.000000 \n",
"25% 27.300000 0.243750 24.000000 0.000000 \n",
"50% 32.000000 0.372500 29.000000 0.000000 \n",
"75% 36.600000 0.626250 41.000000 1.000000 \n",
"max 67.100000 2.420000 81.000000 1.000000 \n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"\n",
"# Ask scikit-learn transformers to return pandas DataFrames instead of NumPy arrays.\n",
"set_config(transform_output=\"pandas\")\n",
"# Seed passed to the data split and to the seeded models below.\n",
"random_state = 9\n",
"\n",
"df = pd.read_csv(\".//scv//diabetes.csv\")\n",
"print(df.describe())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование выборок"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 287,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>1</td>\n",
" <td>105</td>\n",
" <td>58</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.3</td>\n",
" <td>0.187</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>4</td>\n",
" <td>146</td>\n",
" <td>85</td>\n",
" <td>27</td>\n",
" <td>100</td>\n",
" <td>28.9</td>\n",
" <td>0.189</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>3</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.174</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>5</td>\n",
" <td>88</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" <td>27.6</td>\n",
" <td>0.258</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>2</td>\n",
" <td>120</td>\n",
" <td>54</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>26.8</td>\n",
" <td>0.455</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>0</td>\n",
" <td>124</td>\n",
" <td>70</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.254</td>\n",
" <td>36</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>0</td>\n",
" <td>95</td>\n",
" <td>85</td>\n",
" <td>25</td>\n",
" <td>36</td>\n",
" <td>37.4</td>\n",
" <td>0.247</td>\n",
" <td>24</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>1</td>\n",
" <td>97</td>\n",
" <td>66</td>\n",
" <td>15</td>\n",
" <td>140</td>\n",
" <td>23.2</td>\n",
" <td>0.487</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>1</td>\n",
" <td>117</td>\n",
" <td>60</td>\n",
" <td>23</td>\n",
" <td>106</td>\n",
" <td>33.8</td>\n",
" <td>0.466</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>3</td>\n",
" <td>107</td>\n",
" <td>62</td>\n",
" <td>13</td>\n",
" <td>48</td>\n",
" <td>22.9</td>\n",
" <td>0.678</td>\n",
" <td>23</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"196 1 105 58 0 0 24.3 \n",
"69 4 146 85 27 100 28.9 \n",
"494 3 80 0 0 0 0.0 \n",
"463 5 88 78 30 0 27.6 \n",
"653 2 120 54 0 0 26.8 \n",
".. ... ... ... ... ... ... \n",
"322 0 124 70 20 0 27.4 \n",
"109 0 95 85 25 36 37.4 \n",
"27 1 97 66 15 140 23.2 \n",
"651 1 117 60 23 106 33.8 \n",
"197 3 107 62 13 48 22.9 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"196 0.187 21 0 \n",
"69 0.189 27 0 \n",
"494 0.174 22 0 \n",
"463 0.258 37 0 \n",
"653 0.455 27 0 \n",
".. ... ... ... \n",
"322 0.254 36 1 \n",
"109 0.247 24 1 \n",
"27 0.487 22 0 \n",
"651 0.466 27 0 \n",
"197 0.678 23 1 \n",
"\n",
"[614 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"196 0\n",
"69 0\n",
"494 0\n",
"463 0\n",
"653 0\n",
".. ...\n",
"322 1\n",
"109 1\n",
"27 0\n",
"651 0\n",
"197 1\n",
"\n",
"[614 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>9</td>\n",
" <td>154</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>30.9</td>\n",
" <td>0.164</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>379</th>\n",
" <td>0</td>\n",
" <td>93</td>\n",
" <td>100</td>\n",
" <td>39</td>\n",
" <td>72</td>\n",
" <td>43.4</td>\n",
" <td>1.021</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>640</th>\n",
" <td>0</td>\n",
" <td>102</td>\n",
" <td>86</td>\n",
" <td>17</td>\n",
" <td>105</td>\n",
" <td>29.3</td>\n",
" <td>0.695</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>11</td>\n",
" <td>127</td>\n",
" <td>106</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>39.0</td>\n",
" <td>0.190</td>\n",
" <td>51</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>304</th>\n",
" <td>3</td>\n",
" <td>150</td>\n",
" <td>76</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.0</td>\n",
" <td>0.207</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203</th>\n",
" <td>2</td>\n",
" <td>99</td>\n",
" <td>70</td>\n",
" <td>16</td>\n",
" <td>44</td>\n",
" <td>20.4</td>\n",
" <td>0.235</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>605</th>\n",
" <td>1</td>\n",
" <td>124</td>\n",
" <td>60</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.8</td>\n",
" <td>0.514</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>561</th>\n",
" <td>0</td>\n",
" <td>198</td>\n",
" <td>66</td>\n",
" <td>32</td>\n",
" <td>274</td>\n",
" <td>41.3</td>\n",
" <td>0.502</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>280</th>\n",
" <td>0</td>\n",
" <td>146</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>37.9</td>\n",
" <td>0.334</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103</th>\n",
" <td>1</td>\n",
" <td>81</td>\n",
" <td>72</td>\n",
" <td>18</td>\n",
" <td>40</td>\n",
" <td>26.6</td>\n",
" <td>0.283</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"669 9 154 78 30 100 30.9 \n",
"379 0 93 100 39 72 43.4 \n",
"640 0 102 86 17 105 29.3 \n",
"658 11 127 106 0 0 39.0 \n",
"304 3 150 76 0 0 21.0 \n",
".. ... ... ... ... ... ... \n",
"203 2 99 70 16 44 20.4 \n",
"605 1 124 60 32 0 35.8 \n",
"561 0 198 66 32 274 41.3 \n",
"280 0 146 70 0 0 37.9 \n",
"103 1 81 72 18 40 26.6 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"669 0.164 45 0 \n",
"379 1.021 35 0 \n",
"640 0.695 27 0 \n",
"658 0.190 51 0 \n",
"304 0.207 37 0 \n",
".. ... ... ... \n",
"203 0.235 27 0 \n",
"605 0.514 21 0 \n",
"561 0.502 28 1 \n",
"280 0.334 28 1 \n",
"103 0.283 24 0 \n",
"\n",
"[154 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>379</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>640</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>304</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>605</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>561</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>280</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"669 0\n",
"379 0\n",
"640 0\n",
"658 0\n",
"304 0\n",
".. ...\n",
"203 0\n",
"605 0\n",
"561 1\n",
"280 1\n",
"103 0\n",
"\n",
"[154 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input: DataFrame,\n",
"    stratify_colname: str = \"y\",\n",
"    frac_train: float = 0.6,\n",
"    frac_val: float = 0.15,\n",
"    frac_test: float = 0.25,\n",
"    random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"Split a DataFrame into stratified train / validation / test parts.\n",
"\n",
"    Returns (df_train, df_val, df_test, y_train, y_val, y_test). The df_* parts keep\n",
"    every column of df_input, including stratify_colname; the y_* parts contain only\n",
"    stratify_colname. A part whose fraction is 0 is returned as an empty DataFrame.\n",
"    random_state (None = unseeded) is forwarded to every train_test_split call.\n",
"\n",
"    Raises ValueError if a fraction is out of range, the fractions do not sum to 1\n",
"    (within floating-point tolerance), or stratify_colname is not a column.\n",
"    \"\"\"\n",
"    import math\n",
"\n",
"    if not (0 < frac_train < 1) or not (0 <= frac_val <= 1) or not (0 <= frac_test <= 1):\n",
"        raise ValueError(\"frac_train must be in (0, 1); frac_val and frac_test must be in [0, 1].\")\n",
"\n",
"    # Exact == on floats would reject valid inputs such as 0.7 + 0.2 + 0.1.\n",
"    if not math.isclose(frac_train + frac_val + frac_test, 1.0):\n",
"        raise ValueError(\"fractions %f, %f, %f do not add up to 1.0\" %\n",
"                         (frac_train, frac_val, frac_test))\n",
"\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(f\"{stratify_colname} is not a column in the DataFrame.\")\n",
"\n",
"    # NOTE(review): X keeps the target column, so callers must drop it before fitting\n",
"    # a model (columns_to_drop in the pipeline cell lists \"Outcome\" for this reason).\n",
"    X = df_input\n",
"    y = df_input[[stratify_colname]]\n",
"\n",
"    df_train, df_temp, y_train, y_temp = train_test_split(\n",
"        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
"    )\n",
"\n",
"    if frac_val == 0:\n",
"        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
"    if frac_test == 0:\n",
"        # train_test_split rejects test_size=0, so the whole remainder becomes validation.\n",
"        return df_train, df_temp, pd.DataFrame(), y_train, y_temp, pd.DataFrame()\n",
"\n",
"    relative_frac_test = frac_test / (frac_val + frac_test)\n",
"\n",
"    df_val, df_test, y_val, y_test = train_test_split(\n",
"        df_temp,\n",
"        y_temp,\n",
"        stratify=y_temp,\n",
"        test_size=relative_frac_test,\n",
"        random_state=random_state,\n",
"    )\n",
"\n",
"    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"\n",
"    return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"\n",
"# 80/20 stratified split on Outcome; frac_val=0.0 makes X_val / y_val empty frames.\n",
"(X_train, X_val, X_test, y_train, y_val, y_test) = split_stratified_into_train_val_test(\n",
"    df,\n",
"    stratify_colname=\"Outcome\",\n",
"    frac_train=0.80,\n",
"    frac_val=0.0,\n",
"    frac_test=0.20,\n",
"    random_state=random_state,\n",
")\n",
"\n",
"# Each label is displayed immediately before the frame it names.\n",
"display(\n",
"    \"X_train\", X_train,\n",
"    \"y_train\", y_train,\n",
"    \"X_test\", X_test,\n",
"    \"y_test\", y_test,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование конвейера для классификации данных"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 288,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"class DiabetFeatures(BaseEstimator, TransformerMixin):\n",
"    \"\"\"No-op feature transformer: fit learns nothing and transform returns X unchanged.\n",
"\n",
"    Placeholder for dataset-specific feature engineering; the pipelines built in\n",
"    this cell do not use it.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        pass\n",
"\n",
"    def fit(self, X, y=None):\n",
"        \"\"\"Nothing to learn; return self so calls can be chained.\"\"\"\n",
"        return self\n",
"\n",
"    def transform(self, X):\n",
"        \"\"\"Return X unchanged (TransformerMixin.fit_transform requires this method).\"\"\"\n",
"        return X\n",
"\n",
"\n",
2024-11-09 10:46:39 +04:00
"# Columns removed after preprocessing; \"Outcome\" is the target and must not reach the model.\n",
"columns_to_drop = [\"Pregnancies\", \"SkinThickness\", \"Insulin\", \"BMI\", \"Outcome\"]\n",
"# Numeric features that go through median imputation and standardization.\n",
"num_columns = [\"Glucose\", \"Age\", \"BloodPressure\", \"DiabetesPedigreeFunction\"]\n",
2024-11-08 22:14:23 +04:00
"# No categorical features here; the categorical branch is kept but receives no columns.\n",
"cat_columns = []\n",
"\n",
"# Numeric branch: median imputation, then standardization.\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Categorical branch: constant-fill missing values, then one-hot encode (first level dropped).\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Transform num_columns and cat_columns; every other column is passed through unchanged.\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"\n",
"# Remove columns_to_drop (which includes the target) and keep everything else.\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# NOTE(review): \"Cabin_type\" is not a column of diabetes.csv, and this transformer is not\n",
"# part of pipeline_end; fitting it would likely raise. Looks like a leftover - confirm.\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Preprocessing used by the models: transform features, then drop the excluded columns.\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Демонстрация работы конвейера"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 289,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Glucose</th>\n",
" <th>Age</th>\n",
" <th>BloodPressure</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>-0.478144</td>\n",
" <td>-1.029257</td>\n",
" <td>-0.554050</td>\n",
" <td>-0.849205</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>0.818506</td>\n",
" <td>-0.522334</td>\n",
" <td>0.804885</td>\n",
" <td>-0.843172</td>\n",
" </tr>\n",
" <tr>\n",
" <th>494</th>\n",
" <td>-1.268784</td>\n",
" <td>-0.944770</td>\n",
" <td>-3.473244</td>\n",
" <td>-0.888421</td>\n",
" </tr>\n",
" <tr>\n",
" <th>463</th>\n",
" <td>-1.015779</td>\n",
" <td>0.322537</td>\n",
" <td>0.452568</td>\n",
" <td>-0.635028</td>\n",
" </tr>\n",
" <tr>\n",
" <th>653</th>\n",
" <td>-0.003760</td>\n",
" <td>-0.522334</td>\n",
" <td>-0.755374</td>\n",
" <td>-0.040763</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>0.122742</td>\n",
" <td>0.238050</td>\n",
" <td>0.049921</td>\n",
" <td>-0.647095</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>-0.794400</td>\n",
" <td>-0.775796</td>\n",
" <td>0.804885</td>\n",
" <td>-0.668211</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>-0.731149</td>\n",
" <td>-0.944770</td>\n",
" <td>-0.151403</td>\n",
" <td>0.055767</td>\n",
" </tr>\n",
" <tr>\n",
" <th>651</th>\n",
" <td>-0.098637</td>\n",
" <td>-0.522334</td>\n",
" <td>-0.453388</td>\n",
" <td>-0.007581</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>-0.414893</td>\n",
" <td>-0.860283</td>\n",
" <td>-0.352726</td>\n",
" <td>0.631933</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
2024-11-09 10:46:39 +04:00
"<p>614 rows × 4 columns</p>\n",
2024-11-08 22:14:23 +04:00
"</div>"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
" Glucose Age BloodPressure DiabetesPedigreeFunction\n",
"196 -0.478144 -1.029257 -0.554050 -0.849205\n",
"69 0.818506 -0.522334 0.804885 -0.843172\n",
"494 -1.268784 -0.944770 -3.473244 -0.888421\n",
"463 -1.015779 0.322537 0.452568 -0.635028\n",
"653 -0.003760 -0.522334 -0.755374 -0.040763\n",
".. ... ... ... ...\n",
"322 0.122742 0.238050 0.049921 -0.647095\n",
"109 -0.794400 -0.775796 0.804885 -0.668211\n",
"27 -0.731149 -0.944770 -0.151403 0.055767\n",
"651 -0.098637 -0.522334 -0.453388 -0.007581\n",
"197 -0.414893 -0.860283 -0.352726 0.631933\n",
2024-11-08 22:14:23 +04:00
"\n",
2024-11-09 10:46:39 +04:00
"[614 rows x 4 columns]"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 289,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit the preprocessing pipeline on the training split and show the resulting feature table.\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(data=preprocessing_result, columns=pipeline_end.get_feature_names_out())\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование набора моделей для классификации"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 290,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [],
"source": [
2024-11-09 09:14:53 +04:00
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
2024-11-08 22:14:23 +04:00
"\n",
2024-11-09 09:14:53 +04:00
"class_models = {\n",
"    \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
"    # \"ridge\" used to be declared twice; a dict literal keeps only the last value, so the\n",
"    # RidgeClassifierCV entry was silently discarded. RidgeClassifierCV also has no\n",
"    # predict_proba, which the training loop calls, so the L2-penalized logistic\n",
"    # regression that was actually in effect is kept under this key.\n",
"    \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
"    \"decision_tree\": {\n",
"        \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
"    },\n",
"    \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
"    \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
"    \"gradient_boosting\": {\n",
"        # Seeded like the other stochastic models so reruns reproduce the same scores.\n",
"        \"model\": ensemble.GradientBoostingClassifier(n_estimators=210, random_state=random_state)\n",
"    },\n",
"    \"random_forest\": {\n",
"        \"model\": ensemble.RandomForestClassifier(\n",
"            max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
"        )\n",
"    },\n",
"    \"mlp\": {\n",
"        \"model\": neural_network.MLPClassifier(\n",
"            hidden_layer_sizes=(7,),\n",
"            max_iter=500,\n",
"            early_stopping=True,\n",
"            random_state=random_state,\n",
"        )\n",
"    },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 291,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-09 10:46:39 +04:00
"Model: logistic\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-08 22:14:23 +04:00
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"# Fit every candidate on the training split and store its predictions and scores in its entry.\n",
"for model_name, entry in class_models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    model = entry[\"model\"]\n",
"\n",
"    model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)]).fit(\n",
"        X_train, y_train.values.ravel()\n",
"    )\n",
"\n",
"    y_train_predict = model_pipeline.predict(X_train)\n",
"    y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
"    # Predict label 1 when the probability in column 1 exceeds 0.5.\n",
"    y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
"    entry.update(pipeline=model_pipeline, probs=y_test_probs, preds=y_test_predict)\n",
"\n",
"    # (result key, scorer, y_true, y_pred or y_score), applied in this order.\n",
"    scoring_plan = [\n",
"        (\"Precision_train\", metrics.precision_score, y_train, y_train_predict),\n",
"        (\"Precision_test\", metrics.precision_score, y_test, y_test_predict),\n",
"        (\"Recall_train\", metrics.recall_score, y_train, y_train_predict),\n",
"        (\"Recall_test\", metrics.recall_score, y_test, y_test_predict),\n",
"        (\"Accuracy_train\", metrics.accuracy_score, y_train, y_train_predict),\n",
"        (\"Accuracy_test\", metrics.accuracy_score, y_test, y_test_predict),\n",
"        (\"ROC_AUC_test\", metrics.roc_auc_score, y_test, y_test_probs),\n",
"        (\"F1_train\", metrics.f1_score, y_train, y_train_predict),\n",
"        (\"F1_test\", metrics.f1_score, y_test, y_test_predict),\n",
"        (\"MCC_test\", metrics.matthews_corrcoef, y_test, y_test_predict),\n",
"        (\"Cohen_kappa_test\", metrics.cohen_kappa_score, y_test, y_test_predict),\n",
"        (\"Confusion_matrix\", metrics.confusion_matrix, y_test, y_test_predict),\n",
"    ]\n",
"    for metric_key, scorer, truth, estimate in scoring_plan:\n",
"        entry[metric_key] = scorer(truth, estimate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сводная таблица оценок качества для использованных моделей классификации\n",
"\n",
"Матрица неточностей\n"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 292,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
2024-11-09 10:46:39 +04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAQ9CAYAAACMbQYZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxUZd8G8OuwDQgMiwIDCsjiruRWiruEorkmT5bZm3u576VWKu5mi6ShlRloaqaVpJaampp77rnihoKyaCAgKNvMef8gJycYnIGBmTNc3+dzPo9znzP3/Ib0XNznPosgiqIIIiIiIiIiAgBYGLsAIiIiIiIiU8JBEhERERER0VM4SCIiIiIiInoKB0lERERERERP4SCJiIiIiIjoKRwkERERERERPYWDJCIiIiIioqdwkERERERERPQUDpKIiIiIiIiewkESGV1MTAwEQcCtW7cqpP9bt25BEATExMQYpL/9+/dDEATs37/fIP0RERGZk4iICAiCoNO2giAgIiKiYgsiKgMOkoi0WLFihcEGVkREREQkHVbGLoCoovn6+uLx48ewtrbW630rVqxAjRo1MHjwYI32Dh064PHjx7CxsTFglURERObhgw8+wPTp041dBlG5cJBEZk8QBNja2hqsPwsLC4P2R0REZC5ycnJgb28PKyv+iknSxtPtyCStWLECjRo1gkwmg5eXF8aMGYOMjIxi20VFRcHf3x92dnZ44YUXcPDgQXTq1AmdOnVSb1PSNUkpKSkYMmQIatWqBZlMBk9PT/Tp00d9XVTt2rVx8eJFHDhwAIIgQBAEdZ/arkk6fvw4XnrpJbi4uMDe3h5BQUH47LPPDPuDISIiMhFPrj26dOkSXn/9dbi4uKBdu3YlXpOUl5eHSZMmwc3NDY6Ojujduzfu3LlTYr/79+9Hy5YtYWtri4CAAHz55Zdar3Nat24dWrRoATs7O7i6uuK1115DYmJihXxfqlo4zCeTExERgTlz5iA0NBSjRo1CXFwcVq5ciRMnTuDw4cPq0+ZWrlyJsWPHon379pg0aRJu3bqFvn37wsXFBbVq1Sr1M8LDw3Hx4kWMGzcOtWvXxr1797B7924kJCSgdu3aiIyMxLhx4+Dg4ID3338fAODh4aG1v927d6Nnz57w9PTEhAkToFAocPnyZWzfvh0TJkww3A+HiIjIxLzyyiuoU6cOFi5cCFEUce/evWLbDB8+HOvWrcPrr7+ONm3a4Pfff0ePHj2KbXfmzBl069YNnp6emDNnDpRKJebOnQs3N7di2y5YsAAzZ85E//79MXz4cNy/fx/Lly9Hhw4dcObMGTg7O1fE16WqQiQysujoaBGAGB8fL967d0+0sbERu3btKiqVSvU2n3/+uQhA/Oabb0RRFMW8vDyxevXq4vPPPy8WFBSot4uJiREBiB07dlS3xcfHiwDE6OhoURRF8cGDByIA8aOPPiq1rkaNGmn088S+fftEAOK+fftEURTFwsJC0c/PT/T19RUfPHigsa1KpdL9B0FERCQhs2fPFgGIAwYMKLH9ibNnz4oAxNGjR2ts9/rrr4sAxNmzZ6vbevXqJVarVk28e/euuu3atWuilZWVRp+3bt0SLS0txQULFmj0ef78edHKyqpYO5G+eLodmZQ9e/YgPz8fEydOhIXFv389R4wYAblcjl9++QUAcPLkSaSlpWHEiBEa5z0PHDgQLi4upX6GnZ0dbGxssH//fjx48KDcNZ85cwbx8fGYOHFisaNWut4ClYiISKpGjhxZ6vpff/0VADB+/HiN9okTJ2q8ViqV2LNnD/r27QsvLy91e2BgILp3766x7U8//QSVSoX+/fvj77//Vi8KhQJ16tTBvn37yvGNiHi6HZmY27dvAwDq1aun0W5jYwN/f3/1+if/HxgYqLGdlZUVateuXepnyGQyfPjhh5gyZQo8PDzQunVr9OzZE2+++SYUCoXeNd+4cQMA0LhxY73fS0REJH
V+fn6lrr99+zYsLCwQEBCg0f7frL937x4eP35cLNuB4nl/7do1iKKIOnXqlPiZ+t7Rlui/OEiiKmnixIno1asXYmNjsWvXLsycOROLFi3C77//jmbNmhm7PCIiIsmws7Or9M9UqVQQBAE7duyApaVlsfUODg6VXhOZF55uRybF19cXABAXF6fRnp+fj/j4ePX6J/9//fp1je0KCwvVd6h7loCAAEyZMgW//fYbLly4gPz8fHzyySfq9bqeKvfkyNiFCxd02p6IiKgq8fX1hUqlUp958cR/s97d3R22trbFsh0onvcBAQEQRRF+fn4IDQ0ttrRu3drwX4SqFA6SyKSEhobCxsYGy5YtgyiK6vbVq1cjMzNTfSecli1bonr16li1ahUKCwvV261fv/6Z1xk9evQIubm5Gm0BAQFwdHREXl6eus3e3r7E247/V/PmzeHn54fIyMhi2z/9HYiIiKqiJ9cTLVu2TKM9MjJS47WlpSVCQ0MRGxuLpKQkdfv169exY8cOjW379esHS0tLzJkzp1jWiqKItLQ0A34Dqop4uh2ZFDc3N8yYMQNz5sxBt27d0Lt3b8TFxWHFihV4/vnn8cYbbwAoukYpIiIC48aNQ0hICPr3749bt24hJiYGAQEBpc4CXb16FS+++CL69++Phg0bwsrKClu2bEFqaipee+019XYtWrTAypUrMX/+fAQGBsLd3R0hISHF+rOwsMDKlSvRq1cvNG3aFEOGDIGnpyeuXLmCixcvYteuXYb/QREREUlE06ZNMWDAAKxYsQKZmZlo06YN9u7dW+KMUUREBH777Te0bdsWo0aNglKpxOeff47GjRvj7Nmz6u0CAgIwf/58zJgxQ/0IEEdHR8THx2PLli146623MHXq1Er8lmRuOEgikxMREQE3Nzd8/vnnmDRpElxdXfHWW29h4cKFGhdijh07FqIo4pNPPsHUqVPx3HPPYevWrRg/fjxsbW219u/t7Y0BAwZg7969+Pbbb2FlZYX69etj06ZNCA8PV283a9Ys3L59G0uWLMHDhw/RsWPHEgdJABAWFoZ9+/Zhzpw5+OSTT6BSqRAQEIARI0YY7gdDREQkUd988w3c3Nywfv16xMbGIiQkBL/88gu8vb01tmvRogV27NiBqVOnYubMmfD29sbcuXNx+fJlXLlyRWPb6dOno27duli6dCnmzJkDoCjju3btit69e1fadyPzJIg8H4jMiEqlgpubG/r164dVq1YZuxwiIiIygL59++LixYu4du2asUuhKoLXJJFk5ebmFjsPee3atUhPT0enTp2MUxQRERGVy+PHjzVeX7t2Db/++iuznSoVZ5JIsvbv349JkybhlVdeQfXq1XH69GmsXr0aDRo0wKlTp2BjY2PsEomIiEhPnp6eGDx4sPr5iCtXrkReXh7OnDmj9blIRIbGa5JIsmrXrg1vb28sW7YM6enpcHV1xZtvvonFixdzgERERCRR3bp1w3fffYeUlBTIZDIEBwdj4cKFHCBRpeJMEhERERERSULt2rVx+/btYu2jR49GVFQUcnNzMWXKFGzcuBF5eXkICwvDihUr4OHhodfncJBERERERESScP/+fSiVSvXrCxcuoEuXLti3bx86deqEUaNG4ZdffkFMTAycnJwwduxYWFhY4PDhw3p9DgdJREREREQkSRMnTsT27dtx7do1ZGVlwc3NDRs2bMD//vc/AMCVK1fQoEEDHD16FK1bt9a5X16TVMlUKhWSkpLg6OhY6gNPicyRKIp4+PAhvLy8YGFh+Jtr5ubmIj8/v9RtbGxsSn2OFhFVPcxmquoqMp91yWZRFIv925PJZJDJZKW+Lz8/H+vWrcPkyZMhCAJOnTqFgoIChIaGqrepX78+fHx8OEgydUlJScUenEZU1SQmJqJWrVoG7TM3Nxd+vg5IuacsdTuFQoH4+HgOlIhIjdlMVMTQ+axrNjs4OCA7O1ujbfbs2YiIiCj1fbGxscjIyMDgwYMBAC
kpKbCxsYGzs7PGdh4eHkhJSdGrdg6SKpmjoyMA4Pbp2pA78DFVxvBy3SbGLqHKKkQBDuFX9b8DQ8rPz0fKPSWun/S
2024-11-08 22:14:23 +04:00
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import math\n",
"\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Two-column grid of confusion matrices, one panel per model.\n",
"# Round the row count up: int(len / 2) loses a row for an odd number of\n",
"# models and ax.flat[index] then raises IndexError on the last model.\n",
"# squeeze=False keeps ax two-dimensional even when there is a single row.\n",
"n_cols = 2\n",
"n_rows = max(1, math.ceil(len(class_models) / n_cols))\n",
"_, ax = plt.subplots(\n",
"    n_rows, n_cols, figsize=(12, 10), sharex=False, sharey=False, squeeze=False\n",
")\n",
"for index, key in enumerate(class_models.keys()):\n",
"    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Healthy\", \"Sick\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"    disp.ax_.set_title(key)\n",
"\n",
"# Hide grid cells left empty when the number of models is odd.\n",
"for unused_ax in ax.flat[len(class_models):]:\n",
"    unused_ax.set_visible(False)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 293,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-11-09 10:46:39 +04:00
"#T_0c53d_row0_col0 {\n",
" background-color: #20a486;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row0_col1, #T_0c53d_row3_col0, #T_0c53d_row3_col2, #T_0c53d_row5_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_0c53d_row0_col2 {\n",
" background-color: #24868e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row0_col3, #T_0c53d_row3_col1 {\n",
" background-color: #2fb47c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row0_col4 {\n",
" background-color: #7b02a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row0_col5, #T_0c53d_row0_col7, #T_0c53d_row3_col4, #T_0c53d_row3_col6 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_0c53d_row0_col6 {\n",
" background-color: #6700a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row1_col0 {\n",
" background-color: #23a983;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row1_col1, #T_0c53d_row2_col0 {\n",
" background-color: #90d743;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_0c53d_row1_col2, #T_0c53d_row6_col3, #T_0c53d_row7_col0, #T_0c53d_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row1_col3, #T_0c53d_row2_col3, #T_0c53d_row4_col3, #T_0c53d_row5_col1, #T_0c53d_row7_col3 {\n",
" background-color: #1e9b8a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row1_col4 {\n",
" background-color: #7d03a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row1_col5 {\n",
" background-color: #cb4679;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row1_col6 {\n",
" background-color: #6300a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row1_col7 {\n",
" background-color: #bb3488;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row2_col1 {\n",
" background-color: #56c667;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_0c53d_row2_col2 {\n",
" background-color: #67cc5c;\n",
" color: #000000;\n",
"}\n",
"#T_0c53d_row2_col4 {\n",
" background-color: #ca457a;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_0c53d_row2_col5 {\n",
" background-color: #b83289;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_0c53d_row2_col6 {\n",
" background-color: #c9447a;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_0c53d_row2_col7 {\n",
" background-color: #a72197;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row3_col3 {\n",
" background-color: #50c46a;\n",
" color: #000000;\n",
"}\n",
"#T_0c53d_row3_col5 {\n",
" background-color: #b22b8f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row3_col7, #T_0c53d_row5_col7 {\n",
" background-color: #c13b82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row4_col0 {\n",
" background-color: #29af7f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row4_col1 {\n",
" background-color: #28ae80;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row4_col2 {\n",
" background-color: #1f9e89;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row4_col4 {\n",
" background-color: #910ea3;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row4_col5 {\n",
" background-color: #9d189d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row4_col6 {\n",
" background-color: #8707a6;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row4_col7, #T_0c53d_row6_col5 {\n",
" background-color: #8d0ba5;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row5_col0 {\n",
" background-color: #1f948c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row5_col2 {\n",
" background-color: #22a785;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row5_col4 {\n",
" background-color: #7401a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row5_col5 {\n",
" background-color: #9511a1;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row5_col6 {\n",
" background-color: #7801a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row6_col0 {\n",
" background-color: #58c765;\n",
" color: #000000;\n",
"}\n",
"#T_0c53d_row6_col1 {\n",
" background-color: #24aa83;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row6_col2 {\n",
" background-color: #32b67a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row6_col4 {\n",
" background-color: #b12a90;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row6_col6 {\n",
" background-color: #ac2694;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row6_col7 {\n",
" background-color: #6001a6;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0c53d_row7_col2 {\n",
" background-color: #23898e;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_0c53d_row7_col4, #T_0c53d_row7_col5, #T_0c53d_row7_col6, #T_0c53d_row7_col7 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-11-09 10:46:39 +04:00
"<table id=\"T_0c53d\">\n",
2024-11-08 22:14:23 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_0c53d_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_0c53d_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_0c53d_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_0c53d_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_0c53d_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_0c53d_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_0c53d_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_0c53d_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_0c53d_level0_row0\" class=\"row_heading level0 row0\" >naive_bayes</th>\n",
" <td id=\"T_0c53d_row0_col0\" class=\"data row0 col0\" >0.678571</td>\n",
" <td id=\"T_0c53d_row0_col1\" class=\"data row0 col1\" >0.734694</td>\n",
" <td id=\"T_0c53d_row0_col2\" class=\"data row0 col2\" >0.532710</td>\n",
" <td id=\"T_0c53d_row0_col3\" class=\"data row0 col3\" >0.666667</td>\n",
" <td id=\"T_0c53d_row0_col4\" class=\"data row0 col4\" >0.749186</td>\n",
" <td id=\"T_0c53d_row0_col5\" class=\"data row0 col5\" >0.798701</td>\n",
" <td id=\"T_0c53d_row0_col6\" class=\"data row0 col6\" >0.596859</td>\n",
" <td id=\"T_0c53d_row0_col7\" class=\"data row0 col7\" >0.699029</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0c53d_level0_row1\" class=\"row_heading level0 row1\" >logistic</th>\n",
" <td id=\"T_0c53d_row1_col0\" class=\"data row1 col0\" >0.696774</td>\n",
" <td id=\"T_0c53d_row1_col1\" class=\"data row1 col1\" >0.717391</td>\n",
" <td id=\"T_0c53d_row1_col2\" class=\"data row1 col2\" >0.504673</td>\n",
" <td id=\"T_0c53d_row1_col3\" class=\"data row1 col3\" >0.611111</td>\n",
" <td id=\"T_0c53d_row1_col4\" class=\"data row1 col4\" >0.750814</td>\n",
" <td id=\"T_0c53d_row1_col5\" class=\"data row1 col5\" >0.779221</td>\n",
" <td id=\"T_0c53d_row1_col6\" class=\"data row1 col6\" >0.585366</td>\n",
" <td id=\"T_0c53d_row1_col7\" class=\"data row1 col7\" >0.660000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0c53d_level0_row2\" class=\"row_heading level0 row2\" >gradient_boosting</th>\n",
" <td id=\"T_0c53d_row2_col0\" class=\"data row2 col0\" >0.949749</td>\n",
" <td id=\"T_0c53d_row2_col1\" class=\"data row2 col1\" >0.673469</td>\n",
" <td id=\"T_0c53d_row2_col2\" class=\"data row2 col2\" >0.883178</td>\n",
" <td id=\"T_0c53d_row2_col3\" class=\"data row2 col3\" >0.611111</td>\n",
" <td id=\"T_0c53d_row2_col4\" class=\"data row2 col4\" >0.942997</td>\n",
" <td id=\"T_0c53d_row2_col5\" class=\"data row2 col5\" >0.759740</td>\n",
" <td id=\"T_0c53d_row2_col6\" class=\"data row2 col6\" >0.915254</td>\n",
" <td id=\"T_0c53d_row2_col7\" class=\"data row2 col7\" >0.640777</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0c53d_level0_row3\" class=\"row_heading level0 row3\" >random_forest</th>\n",
" <td id=\"T_0c53d_row3_col0\" class=\"data row3 col0\" >0.990741</td>\n",
" <td id=\"T_0c53d_row3_col1\" class=\"data row3 col1\" >0.633333</td>\n",
" <td id=\"T_0c53d_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_0c53d_row3_col3\" class=\"data row3 col3\" >0.703704</td>\n",
" <td id=\"T_0c53d_row3_col4\" class=\"data row3 col4\" >0.996743</td>\n",
" <td id=\"T_0c53d_row3_col5\" class=\"data row3 col5\" >0.753247</td>\n",
" <td id=\"T_0c53d_row3_col6\" class=\"data row3 col6\" >0.995349</td>\n",
" <td id=\"T_0c53d_row3_col7\" class=\"data row3 col7\" >0.666667</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0c53d_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_0c53d_row4_col0\" class=\"data row4 col0\" >0.730159</td>\n",
" <td id=\"T_0c53d_row4_col1\" class=\"data row4 col1\" >0.622642</td>\n",
" <td id=\"T_0c53d_row4_col2\" class=\"data row4 col2\" >0.644860</td>\n",
" <td id=\"T_0c53d_row4_col3\" class=\"data row4 col3\" >0.611111</td>\n",
" <td id=\"T_0c53d_row4_col4\" class=\"data row4 col4\" >0.793160</td>\n",
" <td id=\"T_0c53d_row4_col5\" class=\"data row4 col5\" >0.733766</td>\n",
" <td id=\"T_0c53d_row4_col6\" class=\"data row4 col6\" >0.684864</td>\n",
" <td id=\"T_0c53d_row4_col7\" class=\"data row4 col7\" >0.616822</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0c53d_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
" <td id=\"T_0c53d_row5_col0\" class=\"data row5 col0\" >0.602459</td>\n",
" <td id=\"T_0c53d_row5_col1\" class=\"data row5 col1\" >0.583333</td>\n",
" <td id=\"T_0c53d_row5_col2\" class=\"data row5 col2\" >0.686916</td>\n",
" <td id=\"T_0c53d_row5_col3\" class=\"data row5 col3\" >0.777778</td>\n",
" <td id=\"T_0c53d_row5_col4\" class=\"data row5 col4\" >0.732899</td>\n",
" <td id=\"T_0c53d_row5_col5\" class=\"data row5 col5\" >0.727273</td>\n",
" <td id=\"T_0c53d_row5_col6\" class=\"data row5 col6\" >0.641921</td>\n",
" <td id=\"T_0c53d_row5_col7\" class=\"data row5 col7\" >0.666667</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0c53d_level0_row6\" class=\"row_heading level0 row6\" >decision_tree</th>\n",
" <td id=\"T_0c53d_row6_col0\" class=\"data row6 col0\" >0.848168</td>\n",
" <td id=\"T_0c53d_row6_col1\" class=\"data row6 col1\" >0.612245</td>\n",
" <td id=\"T_0c53d_row6_col2\" class=\"data row6 col2\" >0.757009</td>\n",
" <td id=\"T_0c53d_row6_col3\" class=\"data row6 col3\" >0.555556</td>\n",
" <td id=\"T_0c53d_row6_col4\" class=\"data row6 col4\" >0.868078</td>\n",
" <td id=\"T_0c53d_row6_col5\" class=\"data row6 col5\" >0.720779</td>\n",
" <td id=\"T_0c53d_row6_col6\" class=\"data row6 col6\" >0.800000</td>\n",
" <td id=\"T_0c53d_row6_col7\" class=\"data row6 col7\" >0.582524</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0c53d_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_0c53d_row7_col0\" class=\"data row7 col0\" >0.513158</td>\n",
" <td id=\"T_0c53d_row7_col1\" class=\"data row7 col1\" >0.532258</td>\n",
" <td id=\"T_0c53d_row7_col2\" class=\"data row7 col2\" >0.546729</td>\n",
" <td id=\"T_0c53d_row7_col3\" class=\"data row7 col3\" >0.611111</td>\n",
" <td id=\"T_0c53d_row7_col4\" class=\"data row7 col4\" >0.661238</td>\n",
" <td id=\"T_0c53d_row7_col5\" class=\"data row7 col5\" >0.675325</td>\n",
" <td id=\"T_0c53d_row7_col6\" class=\"data row7 col6\" >0.529412</td>\n",
" <td id=\"T_0c53d_row7_col7\" class=\"data row7 col7\" >0.568966</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
"<pandas.io.formats.style.Styler at 0x203b5bc8b60>"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 293,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train vs. test classification metrics per model, best test accuracy first.\n",
"# Accuracy/F1 columns are shaded with plasma, precision/recall with viridis.\n",
"class_metrics = pd.DataFrame.from_dict(class_models, orient=\"index\").loc[\n",
"    :,\n",
"    [\n",
"        \"Precision_train\",\n",
"        \"Precision_test\",\n",
"        \"Recall_train\",\n",
"        \"Recall_test\",\n",
"        \"Accuracy_train\",\n",
"        \"Accuracy_test\",\n",
"        \"F1_train\",\n",
"        \"F1_test\",\n",
"    ],\n",
"]\n",
"(\n",
"    class_metrics.sort_values(\"Accuracy_test\", ascending=False)\n",
"    .style.background_gradient(\n",
"        cmap=\"plasma\",\n",
"        low=0.3,\n",
"        high=1,\n",
"        subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"viridis\",\n",
"        low=1,\n",
"        high=0.3,\n",
"        subset=[\"Precision_train\", \"Precision_test\", \"Recall_train\", \"Recall_test\"],\n",
"    )\n",
")"
]
},
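{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"Ни одна модель не достигает 100% точности на обучающей выборке, однако у ансамблевых моделей и дерева решений заметен разрыв между обучающей и тестовой выборками, что указывает на переобучение: случайный лес (Accuracy 0.997 на обучении против 0.753 на тесте, F1 0.995 против 0.667), градиентный бустинг (0.943 против 0.760) и дерево решений (0.868 против 0.721). Логистическая регрессия, ридж-регрессия, KNN, наивный байесовский классификатор и многослойный перцептрон показывают сопоставимые значения на обеих выборках (разница не превышает ~0.06). Наилучшая точность на тестовой выборке у наивного байесовского классификатора (0.799).\n",
"\n",
"Далее модели сравниваются на тестовой выборке по ROC AUC, каппе Коэна и коэффициенту корреляции Мэтьюса (MCC).\n"
]
},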
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 294,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-11-09 10:46:39 +04:00
"#T_40ce9_row0_col0 {\n",
" background-color: #7cd250;\n",
" color: #000000;\n",
"}\n",
"#T_40ce9_row0_col1 {\n",
" background-color: #58c765;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_40ce9_row0_col2, #T_40ce9_row2_col3, #T_40ce9_row2_col4 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_40ce9_row0_col3 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row0_col4 {\n",
" background-color: #c6417d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row1_col0 {\n",
" background-color: #26ad81;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row1_col1, #T_40ce9_row4_col1 {\n",
" background-color: #63cb5f;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_40ce9_row1_col2 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row1_col3 {\n",
" background-color: #a82296;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row1_col4 {\n",
" background-color: #b02991;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row2_col0, #T_40ce9_row2_col1 {\n",
" background-color: #a8db34;\n",
2024-11-08 22:14:23 +04:00
" color: #000000;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_40ce9_row2_col2 {\n",
" background-color: #d5546e;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_40ce9_row3_col0 {\n",
" background-color: #54c568;\n",
" color: #000000;\n",
"}\n",
"#T_40ce9_row3_col1 {\n",
" background-color: #38b977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row3_col2 {\n",
" background-color: #d14e72;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row3_col3, #T_40ce9_row3_col4 {\n",
" background-color: #b22b8f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row4_col0 {\n",
" background-color: #48c16e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row4_col2 {\n",
" background-color: #cb4679;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row4_col3, #T_40ce9_row4_col4 {\n",
" background-color: #b7318a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row5_col0 {\n",
" background-color: #2db27d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row5_col1, #T_40ce9_row6_col0 {\n",
" background-color: #22a785;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row5_col2 {\n",
" background-color: #a62098;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row5_col3 {\n",
" background-color: #9613a1;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row5_col4 {\n",
" background-color: #9511a1;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row6_col1 {\n",
" background-color: #228b8d;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_40ce9_row6_col2, #T_40ce9_row7_col2, #T_40ce9_row7_col3, #T_40ce9_row7_col4 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_40ce9_row6_col3, #T_40ce9_row6_col4 {\n",
" background-color: #7b02a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_40ce9_row7_col0, #T_40ce9_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-08 22:14:23 +04:00
"</style>\n",
2024-11-09 10:46:39 +04:00
"<table id=\"T_40ce9\">\n",
2024-11-08 22:14:23 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_40ce9_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_40ce9_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_40ce9_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_40ce9_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_40ce9_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_40ce9_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_40ce9_row0_col0\" class=\"data row0 col0\" >0.779221</td>\n",
" <td id=\"T_40ce9_row0_col1\" class=\"data row0 col1\" >0.660000</td>\n",
" <td id=\"T_40ce9_row0_col2\" class=\"data row0 col2\" >0.825370</td>\n",
" <td id=\"T_40ce9_row0_col3\" class=\"data row0 col3\" >0.498083</td>\n",
" <td id=\"T_40ce9_row0_col4\" class=\"data row0 col4\" >0.501593</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_40ce9_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_40ce9_row1_col0\" class=\"data row1 col0\" >0.727273</td>\n",
" <td id=\"T_40ce9_row1_col1\" class=\"data row1 col1\" >0.666667</td>\n",
" <td id=\"T_40ce9_row1_col2\" class=\"data row1 col2\" >0.824444</td>\n",
" <td id=\"T_40ce9_row1_col3\" class=\"data row1 col3\" >0.443756</td>\n",
" <td id=\"T_40ce9_row1_col4\" class=\"data row1 col4\" >0.456930</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_40ce9_level0_row2\" class=\"row_heading level0 row2\" >naive_bayes</th>\n",
" <td id=\"T_40ce9_row2_col0\" class=\"data row2 col0\" >0.798701</td>\n",
" <td id=\"T_40ce9_row2_col1\" class=\"data row2 col1\" >0.699029</td>\n",
" <td id=\"T_40ce9_row2_col2\" class=\"data row2 col2\" >0.820556</td>\n",
" <td id=\"T_40ce9_row2_col3\" class=\"data row2 col3\" >0.548344</td>\n",
" <td id=\"T_40ce9_row2_col4\" class=\"data row2 col4\" >0.549805</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_40ce9_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
" <td id=\"T_40ce9_row3_col0\" class=\"data row3 col0\" >0.759740</td>\n",
" <td id=\"T_40ce9_row3_col1\" class=\"data row3 col1\" >0.640777</td>\n",
" <td id=\"T_40ce9_row3_col2\" class=\"data row3 col2\" >0.815741</td>\n",
" <td id=\"T_40ce9_row3_col3\" class=\"data row3 col3\" >0.460927</td>\n",
" <td id=\"T_40ce9_row3_col4\" class=\"data row3 col4\" >0.462155</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_40ce9_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_40ce9_row4_col0\" class=\"data row4 col0\" >0.753247</td>\n",
" <td id=\"T_40ce9_row4_col1\" class=\"data row4 col1\" >0.666667</td>\n",
" <td id=\"T_40ce9_row4_col2\" class=\"data row4 col2\" >0.808704</td>\n",
" <td id=\"T_40ce9_row4_col3\" class=\"data row4 col3\" >0.471650</td>\n",
" <td id=\"T_40ce9_row4_col4\" class=\"data row4 col4\" >0.473300</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_40ce9_level0_row5\" class=\"row_heading level0 row5\" >knn</th>\n",
" <td id=\"T_40ce9_row5_col0\" class=\"data row5 col0\" >0.733766</td>\n",
" <td id=\"T_40ce9_row5_col1\" class=\"data row5 col1\" >0.616822</td>\n",
" <td id=\"T_40ce9_row5_col2\" class=\"data row5 col2\" >0.776204</td>\n",
" <td id=\"T_40ce9_row5_col3\" class=\"data row5 col3\" >0.412870</td>\n",
" <td id=\"T_40ce9_row5_col4\" class=\"data row5 col4\" >0.412912</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_40ce9_level0_row6\" class=\"row_heading level0 row6\" >decision_tree</th>\n",
" <td id=\"T_40ce9_row6_col0\" class=\"data row6 col0\" >0.720779</td>\n",
" <td id=\"T_40ce9_row6_col1\" class=\"data row6 col1\" >0.582524</td>\n",
" <td id=\"T_40ce9_row6_col2\" class=\"data row6 col2\" >0.719167</td>\n",
" <td id=\"T_40ce9_row6_col3\" class=\"data row6 col3\" >0.373510</td>\n",
" <td id=\"T_40ce9_row6_col4\" class=\"data row6 col4\" >0.374505</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_40ce9_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_40ce9_row7_col0\" class=\"data row7 col0\" >0.675325</td>\n",
" <td id=\"T_40ce9_row7_col1\" class=\"data row7 col1\" >0.568966</td>\n",
" <td id=\"T_40ce9_row7_col2\" class=\"data row7 col2\" >0.719074</td>\n",
" <td id=\"T_40ce9_row7_col3\" class=\"data row7 col3\" >0.310530</td>\n",
" <td id=\"T_40ce9_row7_col4\" class=\"data row7 col4\" >0.312437</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
"<pandas.io.formats.style.Styler at 0x203be18ef60>"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 294,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test-set quality metrics per model, highest ROC AUC first.\n",
"# ROC AUC, MCC and Cohen's kappa are shaded with plasma; accuracy and F1 with viridis.\n",
"class_metrics = pd.DataFrame.from_dict(class_models, orient=\"index\").loc[\n",
"    :,\n",
"    [\"Accuracy_test\", \"F1_test\", \"ROC_AUC_test\", \"Cohen_kappa_test\", \"MCC_test\"],\n",
"]\n",
"(\n",
"    class_metrics.sort_values(\"ROC_AUC_test\", ascending=False)\n",
"    .style.background_gradient(\n",
"        cmap=\"plasma\",\n",
"        low=0.3,\n",
"        high=1,\n",
"        subset=[\"ROC_AUC_test\", \"MCC_test\", \"Cohen_kappa_test\"],\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"viridis\",\n",
"        low=1,\n",
"        high=0.3,\n",
"        subset=[\"Accuracy_test\", \"F1_test\"],\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 295,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-11-09 10:46:39 +04:00
"'naive_bayes'"
2024-11-08 22:14:23 +04:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Label of the model that ranks first by test-set MCC.\n",
"best_model = str(\n",
"    class_metrics.sort_values(\"MCC_test\", ascending=False).index[0]\n",
")\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 296,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-11-09 10:46:39 +04:00
"'Error items count: 31'"
2024-11-08 22:14:23 +04:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Predicted</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
2024-11-09 10:46:39 +04:00
" <tr>\n",
" <th>64</th>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>114</td>\n",
" <td>66</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>32.8</td>\n",
" <td>0.258</td>\n",
" <td>42</td>\n",
" <td>1</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th>88</th>\n",
" <td>15</td>\n",
" <td>0</td>\n",
" <td>136</td>\n",
" <td>70</td>\n",
" <td>32</td>\n",
" <td>110</td>\n",
" <td>37.1</td>\n",
" <td>0.153</td>\n",
" <td>43</td>\n",
" <td>1</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
2024-11-09 10:46:39 +04:00
" <tr>\n",
" <th>125</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>88</td>\n",
" <td>30</td>\n",
" <td>42</td>\n",
" <td>99</td>\n",
" <td>55.0</td>\n",
" <td>0.496</td>\n",
" <td>26</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>143</th>\n",
" <td>10</td>\n",
" <td>0</td>\n",
" <td>108</td>\n",
" <td>66</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>32.4</td>\n",
" <td>0.272</td>\n",
" <td>42</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>170</th>\n",
" <td>6</td>\n",
" <td>0</td>\n",
" <td>102</td>\n",
" <td>82</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.8</td>\n",
" <td>0.180</td>\n",
" <td>36</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>188</th>\n",
" <td>8</td>\n",
" <td>0</td>\n",
" <td>109</td>\n",
" <td>76</td>\n",
" <td>39</td>\n",
" <td>114</td>\n",
" <td>27.9</td>\n",
" <td>0.640</td>\n",
" <td>31</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>199</th>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>148</td>\n",
" <td>60</td>\n",
" <td>27</td>\n",
" <td>318</td>\n",
" <td>30.9</td>\n",
" <td>0.150</td>\n",
" <td>29</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>214</th>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>112</td>\n",
" <td>82</td>\n",
" <td>32</td>\n",
" <td>175</td>\n",
" <td>34.2</td>\n",
" <td>0.260</td>\n",
" <td>36</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>223</th>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>142</td>\n",
" <td>60</td>\n",
" <td>33</td>\n",
" <td>190</td>\n",
" <td>28.8</td>\n",
" <td>0.687</td>\n",
" <td>61</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>228</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>197</td>\n",
" <td>70</td>\n",
" <td>39</td>\n",
" <td>744</td>\n",
" <td>36.7</td>\n",
" <td>2.329</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>280</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>146</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>37.9</td>\n",
" <td>0.334</td>\n",
" <td>28</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>161</td>\n",
" <td>50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.9</td>\n",
" <td>0.254</td>\n",
" <td>65</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>304</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>150</td>\n",
" <td>76</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.0</td>\n",
" <td>0.207</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>309</th>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>124</td>\n",
" <td>68</td>\n",
" <td>28</td>\n",
" <td>205</td>\n",
" <td>32.9</td>\n",
" <td>0.875</td>\n",
" <td>30</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>335</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>165</td>\n",
" <td>76</td>\n",
" <td>43</td>\n",
" <td>255</td>\n",
" <td>47.9</td>\n",
" <td>0.259</td>\n",
" <td>26</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>395</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>127</td>\n",
" <td>58</td>\n",
" <td>24</td>\n",
" <td>275</td>\n",
" <td>27.7</td>\n",
" <td>1.600</td>\n",
" <td>25</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>397</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>131</td>\n",
" <td>66</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" <td>34.3</td>\n",
" <td>0.196</td>\n",
" <td>22</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>401</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>137</td>\n",
" <td>61</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.2</td>\n",
" <td>0.151</td>\n",
" <td>55</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>406</th>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>115</td>\n",
" <td>72</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>28.9</td>\n",
" <td>0.376</td>\n",
" <td>46</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>510</th>\n",
" <td>12</td>\n",
" <td>0</td>\n",
" <td>84</td>\n",
" <td>72</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>29.7</td>\n",
" <td>0.297</td>\n",
" <td>46</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>541</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>128</td>\n",
" <td>72</td>\n",
" <td>25</td>\n",
" <td>190</td>\n",
" <td>32.4</td>\n",
" <td>0.549</td>\n",
" <td>27</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>549</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>189</td>\n",
" <td>110</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>28.5</td>\n",
" <td>0.680</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>568</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>154</td>\n",
" <td>72</td>\n",
" <td>29</td>\n",
" <td>126</td>\n",
" <td>31.3</td>\n",
" <td>0.338</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>577</th>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>118</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>42.9</td>\n",
" <td>0.693</td>\n",
" <td>21</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>622</th>\n",
" <td>6</td>\n",
" <td>1</td>\n",
" <td>183</td>\n",
" <td>94</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40.8</td>\n",
" <td>1.461</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>630</th>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>114</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.732</td>\n",
" <td>34</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>11</td>\n",
" <td>1</td>\n",
" <td>127</td>\n",
" <td>106</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>39.0</td>\n",
" <td>0.190</td>\n",
" <td>51</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>669</th>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>154</td>\n",
" <td>78</td>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>30.9</td>\n",
" <td>0.164</td>\n",
" <td>45</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>693</th>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>129</td>\n",
" <td>68</td>\n",
" <td>49</td>\n",
" <td>125</td>\n",
" <td>38.5</td>\n",
" <td>0.439</td>\n",
" <td>43</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>730</th>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>130</td>\n",
" <td>78</td>\n",
" <td>23</td>\n",
" <td>79</td>\n",
" <td>28.4</td>\n",
" <td>0.323</td>\n",
" <td>34</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>744</th>\n",
" <td>13</td>\n",
" <td>1</td>\n",
" <td>153</td>\n",
" <td>88</td>\n",
" <td>37</td>\n",
" <td>140</td>\n",
" <td>40.6</td>\n",
" <td>1.174</td>\n",
" <td>39</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Predicted Glucose BloodPressure SkinThickness Insulin \\\n",
"64 7 0 114 66 0 0 \n",
"88 15 0 136 70 32 110 \n",
"125 1 0 88 30 42 99 \n",
"143 10 0 108 66 0 0 \n",
"170 6 0 102 82 0 0 \n",
"188 8 0 109 76 39 114 \n",
"199 4 0 148 60 27 318 \n",
"214 9 0 112 82 32 175 \n",
"223 7 1 142 60 33 190 \n",
"228 4 1 197 70 39 744 \n",
"280 0 0 146 70 0 0 \n",
"294 0 1 161 50 0 0 \n",
"304 3 1 150 76 0 0 \n",
"309 2 0 124 68 28 205 \n",
"335 0 1 165 76 43 255 \n",
"395 2 1 127 58 24 275 \n",
"397 0 0 131 66 40 0 \n",
"401 6 1 137 61 0 0 \n",
"406 4 0 115 72 0 0 \n",
"510 12 0 84 72 31 0 \n",
"541 3 0 128 72 25 190 \n",
"549 4 1 189 110 31 0 \n",
"568 4 1 154 72 29 126 \n",
"577 2 0 118 80 0 0 \n",
"622 6 1 183 94 0 0 \n",
"630 7 0 114 64 0 0 \n",
"658 11 1 127 106 0 0 \n",
"669 9 1 154 78 30 100 \n",
"693 7 0 129 68 49 125 \n",
"730 3 0 130 78 23 79 \n",
"744 13 1 153 88 37 140 \n",
"\n",
" BMI DiabetesPedigreeFunction Age Outcome \n",
"64 32.8 0.258 42 1 \n",
"88 37.1 0.153 43 1 \n",
"125 55.0 0.496 26 1 \n",
"143 32.4 0.272 42 1 \n",
"170 30.8 0.180 36 1 \n",
"188 27.9 0.640 31 1 \n",
"199 30.9 0.150 29 1 \n",
"214 34.2 0.260 36 1 \n",
"223 28.8 0.687 61 0 \n",
"228 36.7 2.329 31 0 \n",
"280 37.9 0.334 28 1 \n",
"294 21.9 0.254 65 0 \n",
"304 21.0 0.207 37 0 \n",
"309 32.9 0.875 30 1 \n",
"335 47.9 0.259 26 0 \n",
"395 27.7 1.600 25 0 \n",
"397 34.3 0.196 22 1 \n",
"401 24.2 0.151 55 0 \n",
"406 28.9 0.376 46 1 \n",
"510 29.7 0.297 46 1 \n",
"541 32.4 0.549 27 1 \n",
"549 28.5 0.680 37 0 \n",
"568 31.3 0.338 37 0 \n",
"577 42.9 0.693 21 1 \n",
"622 40.8 1.461 45 0 \n",
"630 27.4 0.732 34 1 \n",
"658 39.0 0.190 51 0 \n",
"669 30.9 0.164 45 0 \n",
"693 38.5 0.439 43 1 \n",
"730 28.4 0.323 34 1 \n",
"744 40.6 1.174 39 0 "
]
},
"execution_count": 296,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"Outcome\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 297,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>555</th>\n",
" <td>7.0</td>\n",
" <td>124.0</td>\n",
" <td>70.0</td>\n",
" <td>33.0</td>\n",
" <td>215.0</td>\n",
" <td>25.5</td>\n",
" <td>0.161</td>\n",
" <td>37.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
2024-11-08 22:14:23 +04:00
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"555 7.0 124.0 70.0 33.0 215.0 25.5 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"555 0.161 37.0 0.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Glucose</th>\n",
" <th>Age</th>\n",
" <th>BloodPressure</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>555</th>\n",
" <td>0.122742</td>\n",
" <td>0.322537</td>\n",
" <td>0.049921</td>\n",
" <td>-0.927636</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
" Glucose Age BloodPressure DiabetesPedigreeFunction\n",
"555 0.122742 0.322537 0.049921 -0.927636"
2024-11-08 22:14:23 +04:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
2024-11-09 10:46:39 +04:00
"'predicted: 0 (proba: [0.7669925 0.2330075])'"
2024-11-08 22:14:23 +04:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 0'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 555\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 298,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-11-09 10:46:39 +04:00
"{'model__criterion': 'entropy',\n",
" 'model__max_depth': 7,\n",
2024-11-08 22:14:23 +04:00
" 'model__max_features': 'sqrt',\n",
2024-11-09 10:46:39 +04:00
" 'model__n_estimators': 50}"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 298,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 299,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"log2\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 300,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 301,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-11-09 10:46:39 +04:00
"#T_dbff0_row0_col0, #T_dbff0_row0_col2, #T_dbff0_row0_col3, #T_dbff0_row1_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_dbff0_row0_col1, #T_dbff0_row1_col0, #T_dbff0_row1_col2, #T_dbff0_row1_col3 {\n",
" background-color: #26818e;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_dbff0_row0_col4, #T_dbff0_row0_col6, #T_dbff0_row0_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_dbff0_row0_col5, #T_dbff0_row1_col5 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_dbff0_row1_col4, #T_dbff0_row1_col6, #T_dbff0_row1_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-08 22:14:23 +04:00
"</style>\n",
2024-11-09 10:46:39 +04:00
"<table id=\"T_dbff0\">\n",
2024-11-08 22:14:23 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_dbff0_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_dbff0_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_dbff0_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_dbff0_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_dbff0_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_dbff0_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_dbff0_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_dbff0_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_dbff0_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_dbff0_row0_col0\" class=\"data row0 col0\" >0.990741</td>\n",
" <td id=\"T_dbff0_row0_col1\" class=\"data row0 col1\" >0.633333</td>\n",
" <td id=\"T_dbff0_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_dbff0_row0_col3\" class=\"data row0 col3\" >0.703704</td>\n",
" <td id=\"T_dbff0_row0_col4\" class=\"data row0 col4\" >0.996743</td>\n",
" <td id=\"T_dbff0_row0_col5\" class=\"data row0 col5\" >0.753247</td>\n",
" <td id=\"T_dbff0_row0_col6\" class=\"data row0 col6\" >0.995349</td>\n",
" <td id=\"T_dbff0_row0_col7\" class=\"data row0 col7\" >0.666667</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_dbff0_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_dbff0_row1_col0\" class=\"data row1 col0\" >0.861842</td>\n",
" <td id=\"T_dbff0_row1_col1\" class=\"data row1 col1\" >0.673913</td>\n",
" <td id=\"T_dbff0_row1_col2\" class=\"data row1 col2\" >0.612150</td>\n",
" <td id=\"T_dbff0_row1_col3\" class=\"data row1 col3\" >0.574074</td>\n",
" <td id=\"T_dbff0_row1_col4\" class=\"data row1 col4\" >0.830619</td>\n",
" <td id=\"T_dbff0_row1_col5\" class=\"data row1 col5\" >0.753247</td>\n",
" <td id=\"T_dbff0_row1_col6\" class=\"data row1 col6\" >0.715847</td>\n",
" <td id=\"T_dbff0_row1_col7\" class=\"data row1 col7\" >0.620000</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
"<pandas.io.formats.style.Styler at 0x203c01dc0e0>"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 301,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 302,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-11-09 10:46:39 +04:00
"#T_efa8e_row0_col0, #T_efa8e_row1_col0 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_efa8e_row0_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_efa8e_row0_col2, #T_efa8e_row1_col3, #T_efa8e_row1_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_efa8e_row0_col3, #T_efa8e_row0_col4, #T_efa8e_row1_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_efa8e_row1_col1 {\n",
" background-color: #26818e;\n",
2024-11-08 22:14:23 +04:00
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-11-09 10:46:39 +04:00
"<table id=\"T_efa8e\">\n",
2024-11-08 22:14:23 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_efa8e_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_efa8e_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_efa8e_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_efa8e_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_efa8e_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_efa8e_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_efa8e_row0_col0\" class=\"data row0 col0\" >0.753247</td>\n",
" <td id=\"T_efa8e_row0_col1\" class=\"data row0 col1\" >0.666667</td>\n",
" <td id=\"T_efa8e_row0_col2\" class=\"data row0 col2\" >0.808704</td>\n",
" <td id=\"T_efa8e_row0_col3\" class=\"data row0 col3\" >0.471650</td>\n",
" <td id=\"T_efa8e_row0_col4\" class=\"data row0 col4\" >0.473300</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_efa8e_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_efa8e_row1_col0\" class=\"data row1 col0\" >0.753247</td>\n",
" <td id=\"T_efa8e_row1_col1\" class=\"data row1 col1\" >0.620000</td>\n",
" <td id=\"T_efa8e_row1_col2\" class=\"data row1 col2\" >0.846111</td>\n",
" <td id=\"T_efa8e_row1_col3\" class=\"data row1 col3\" >0.439034</td>\n",
" <td id=\"T_efa8e_row1_col4\" class=\"data row1 col4\" >0.442128</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
"<pandas.io.formats.style.Styler at 0x203bd7f7530>"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 302,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 303,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
2024-11-09 10:46:39 +04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2oAAAGjCAYAAABdU+ZeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABYiUlEQVR4nO3df3zNdf/H8efZZj/sx5mJzbQNKT+KhGLyI8KSisuulHR9CV3lVyEllR9RJtdVpKRCU1dJuSqhqx+sKPkRihKt8qPNj40L24z2w875/iGnzjW0Y5/tnH0+j/vt9rldzufzOe/zPq7Zs/fn/fq8Pzan0+kUAAAAAMBn+Hm7AwAAAAAAdwzUAAAAAMDHMFADAAAAAB/DQA0AAAAAfAwDNQAAAADwMQzUAAAAAMDHMFADAAAAAB8T4O0OAAAqR0FBgYqKigxrLzAwUMHBwYa1BwCAJ8yeawzUAMACCgoKVD8hTFmHSgxrMyYmRnv27PGpUAMAWIMVco2BGgBYQFFRkbIOleiXLfUUEV7+qve84w4ltNqroqIinwk0AIB1WCHXGKgBgIWEhdsUFm4rdzsOlb8NAADKy8y5xkANACykxOlQidOYdgAA8DYz5xqrPgIAAACAj2FGDQAsxCGnHCr/pUcj2gAAoLzMnGsM1ADAQhxyyIjiDmNaAQCgfMyca5Q+AgAAAICPYUYNACykxOlUibP85R1GtAEAQHmZOdcYqAGAhZi5lh8AYD1mzjVKHwEAAADAxzCjBgAW4pBTJSa98ggAsB4z5xozagAAAADgY5hRAwALMXMtPwDAesycawzUAMBCzLw6FgDAesyca5Q+AgAAAICPYUYNACzE8dtmRDsAAHibmXONgRoAWEiJQatjGdEGAADlZeZco/QRAAAAAHwMM2oAYCElztObEe0AAOBtZs41BmoAYCFmruUHAFiPmXON0kcAAAAA8DHMqAGAhThkU4lshrQDAIC3mTnXGKgBgIU4nKc3I9oBAMDbzJxrlD4CAAAAgI9hoAYAFlLyW4mIERsAAN7mjVwrKSnRhAkTVL9+fYWEhOiSSy7R1KlT5XT+Pi3ndDo1ceJE1alTRyEhIeratat++uknj74bAzUAAAAAKKOnnnpKc+fO1fPPP6+dO3fqqaee0owZM/Tcc8+5zpkxY4Zmz56tF198URs3blRoaKiSkpJUUFBQ5s/hHjUAsBCjZsOYUQMA+AJv5Nq6devUq1cv9ezZU5JUr149vfnmm/rqq68knZ5NmzVrlh577DH16tVLkvTaa68pOjpaS5cu1e23316mz2FGDQAsxOG0GbYBAOBtRudaXl6e21ZYWFjqM9u1a6e0tDT9+OOPkqRt27Zp7dq16tGjhyRpz549ysrKUteuXV3vsdvtatOmjdavX1/m78aMGgAAAABIiouLc3s9adIkTZ482W3fww8/rLy8PDVu3Fj+/v4qKSnRk08+qf79+0uSsrKyJEnR0dFu74uOjnYdKwsGagBgIZQ+AgDMxOhcy8zMVEREhGt/UFBQqXPffvttvfHGG1q0aJEuv/xybd26VaNGjVJsbKwGDBhQ7r6cwUANACykRH4qMaDqvcSAvgAAUF5G51pERITbQO1sHnzwQT388MOue82aNWumX375RSkpKRowYIBiYmIkSdnZ2apTp47rfdnZ2WrRokWZ+8Q9agAAAABQRidPnpSfn/swyt/fXw6HQ5JUv359xcTEKC0tzXU8Ly9PGzduVGJiYpk/hxk1ALAQp0ELgThZTAQA4AO8kWs333yznnzyScXHx+vyyy/XN998o2eeeUaDBg2SJNlsNo0aNUpPPPGELr30UtWvX18TJkxQbGysevfuXebPYaAGABbCPWoAADPxRq4999xzmjBhgoYNG6ZDhw4pNjZW99xzjyZOnOg656GHHtKJEyf097//XTk5OWrfvr0++ugjBQcHl/lzbM4/PkIbAGBKeXl5stvt+uS7BIWGl7/q/c
Rxh7o3+0W5ubl/WssPAIDRrJBrzKgBgIWUOP1U4jTgpmsu8QEAfICZc43FRAAAAADAxzCjBgAW4pBNDgOu0Tnkg5ceAQCWY+ZcY6AGABbCYiIAADMxc65R+ggAAAAAPoYZNQCwEONuuva9EhEAgPWYOdcYqAGAhZyu5S9/eYcRbQAAUF5mzjVKHwEAAADAxzCjBgAW4pCfSky6OhYAwHrMnGsM1ADAQsxcyw8AsB4z5xqljwAAAADgY5hRAwALccjPtA8GBQBYj5lzjRk1ALCQEqfNsK2s6tWrJ5vNVmobPny4JKmgoEDDhw9XzZo1FRYWpuTkZGVnZ1fUXwEAwES8kWuVhYEaAKBCbdq0SQcPHnRtK1eulCTdeuutkqTRo0dr+fLlWrJkidasWaMDBw6oT58+3uwyAABeR+kjAFhIiUGrY5V4UCJSq1Ytt9fTp0/XJZdcok6dOik3N1cLFizQokWL1KVLF0lSamqqmjRpog0bNqht27bl7isAwLy8kWuVhRk1AMAFy8vLc9sKCwvPe35RUZFef/11DRo0SDabTVu2bFFxcbG6du3qOqdx48aKj4/X+vXrK7r7AAD4LAZqAGAhDqefYZskxcXFyW63u7aUlJTzfv7SpUuVk5OjgQMHSpKysrIUGBioyMhIt/Oio6OVlZVVEX8FAAATMTrXfAmljwBgIUaXiGRmZioiIsK1Pygo6LzvW7BggXr06KHY2Nhy9wEAADOXPjJQAwBcsIiICLeB2vn88ssvWrVqld59913XvpiYGBUVFSknJ8dtVi07O1sxMTFGdxcAgCrD9+b4AAAVxiFjljJ2XMBnp6amqnbt2urZs6drX6tWrVStWjWlpaW59qWnpysjI0OJiYnl/8IAAFPzZq5VNGbUAMBCjHswqGdtOBwOpaamasCAAQoI+D167Ha7Bg8erDFjxigqKkoREREaOXKkEhMTWfERAPCnvJVrlYGBGgCgwq1atUoZGRkaNGhQqWMzZ86Un5+fkpOTVVhYqKSkJL3wwgte6CUAAL6DgRoAWEiJ008lBqxs5Wkb3bt3l9N59hu1g4ODNWfOHM2ZM6fc/QIAWIu3cq0yMFADAAtxyCaHbIa0AwCAt5k513xv6AgAAAAAFseMGgBYiJlLRAAA1mPmXPO9HgEAAACAxTGjBgAWUiI/lRhwjc6INgAAKC8z5xoDtUrmcDh04MABhYeHy2bzvZsWAfgep9Op48ePKzY2Vn5+5QsSh9Mmh9OAm64NaAPmQK4B8BS5VjYM1CrZgQMHFBcX5+1uAKiCMjMzdfHFF3u7G4Abcg3AhSLXzo+BWiULDw+XJP3ydT1FhPneFCu869buN3q7C/BBpxxFWp3xsuv3R3k4DCoRcfhgiQi8g1zD+fzlsmbe7gJ80CkVa63+Q679CQZqlexMWUhEmJ8iwn3vBwLeFeAX5O0uwIcZUVbmcPrJYcDKVka0AXMg13A+AbZq3u4CfJHz9P+Qa+fnez0CAAAAAItjRg0ALKRENpWo/FcwjWgDAIDyMnOuMVADAAsxc4kIAMB6zJxrvtcjAAAAALA4ZtQAwEJKZEx5R0n5uwIAQLmZOdcYqAGAhZi5RAQAYD1mzjXf6xEAAAAAWBwzagBgISVOP5UYcNXQiDYAACgvM+ea7/UIAAAAACyOGTUAsBCnbHIYcNO10wefNwMAsB4z5xoDNQCwEDOXiAAArMfMueZ7PQIAAAAAi2NGDQAsxOG0yeEsf3mHEW0AAFBeZs41BmoAYCEl8lOJAcUURrQBAEB5mTnXfK9HAAAAAOCj6tWrJ5vNVmobPny4JKmgoEDDhw9XzZo1FRYWpuTkZGVnZ3v8OQzUAMBCzpSIGLEBAOBt3si1TZs26eDBg65t5cqVkqRbb71VkjR69GgtX75cS5Ys0Zo1a3TgwAH16dPH4+9G6SMAWIhDfnIYcI3OiDYAACgvb+RarVq13F5Pnz5dl1xyiTp16qTc3FwtWLBAixYtUpcuXSRJqa
mpatKkiTZs2KC2bduW+XNIWgAAAACQlJeX57YVFhae9/yioiK9/vrrGjRokGw2m7Zs2aLi4mJ17drVdU7jxo0VHx+
2024-11-08 22:14:23 +04:00
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
2024-11-09 09:14:53 +04:00
" confusion_matrix=c_matrix, display_labels=[\"Healthy\", \"Sick\"]\n",
2024-11-08 22:14:23 +04:00
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Регрессионная модель"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 304,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Pregnancies', 'Glucose', 'BloodPressure', 'SkinThickness', 'Insulin',\n",
" 'BMI', 'DiabetesPedigreeFunction', 'Age', 'Outcome'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6</td>\n",
" <td>148</td>\n",
" <td>72</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" <td>33.6</td>\n",
" <td>0.627</td>\n",
" <td>50</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>85</td>\n",
" <td>66</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>26.6</td>\n",
" <td>0.351</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>8</td>\n",
" <td>183</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>23.3</td>\n",
" <td>0.672</td>\n",
" <td>32</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>89</td>\n",
" <td>66</td>\n",
" <td>23</td>\n",
" <td>94</td>\n",
" <td>28.1</td>\n",
" <td>0.167</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>137</td>\n",
" <td>40</td>\n",
" <td>35</td>\n",
" <td>168</td>\n",
" <td>43.1</td>\n",
" <td>2.288</td>\n",
" <td>33</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>763</th>\n",
" <td>10</td>\n",
" <td>101</td>\n",
" <td>76</td>\n",
" <td>48</td>\n",
" <td>180</td>\n",
" <td>32.9</td>\n",
" <td>0.171</td>\n",
" <td>63</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>764</th>\n",
" <td>2</td>\n",
" <td>122</td>\n",
" <td>70</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" <td>36.8</td>\n",
" <td>0.340</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>765</th>\n",
" <td>5</td>\n",
" <td>121</td>\n",
" <td>72</td>\n",
" <td>23</td>\n",
" <td>112</td>\n",
" <td>26.2</td>\n",
" <td>0.245</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>766</th>\n",
" <td>1</td>\n",
" <td>126</td>\n",
" <td>60</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.1</td>\n",
" <td>0.349</td>\n",
" <td>47</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>767</th>\n",
" <td>1</td>\n",
" <td>93</td>\n",
" <td>70</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>30.4</td>\n",
" <td>0.315</td>\n",
" <td>23</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>768 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 148 72 35 0 33.6 \n",
"1 1 85 66 29 0 26.6 \n",
"2 8 183 64 0 0 23.3 \n",
"3 1 89 66 23 94 28.1 \n",
"4 0 137 40 35 168 43.1 \n",
".. ... ... ... ... ... ... \n",
"763 10 101 76 48 180 32.9 \n",
"764 2 122 70 27 0 36.8 \n",
"765 5 121 72 23 112 26.2 \n",
"766 1 126 60 0 0 30.1 \n",
"767 1 93 70 31 0 30.4 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"0 0.627 50 1 \n",
"1 0.351 31 0 \n",
"2 0.672 32 1 \n",
"3 0.167 21 0 \n",
"4 2.288 33 1 \n",
".. ... ... ... \n",
"763 0.171 63 0 \n",
"764 0.340 27 0 \n",
"765 0.245 30 0 \n",
"766 0.349 47 1 \n",
"767 0.315 23 0 \n",
"\n",
"[768 rows x 9 columns]"
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 304,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"random_state=9\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\".//scv//diabetes.csv\")\n",
"print(df.columns)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Разделение набора данных на обучающую и тестовые выборки"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 305,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>60</th>\n",
" <td>2</td>\n",
" <td>84</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.304</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>618</th>\n",
" <td>9</td>\n",
" <td>112</td>\n",
" <td>82</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" <td>28.2</td>\n",
" <td>1.282</td>\n",
" <td>50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>346</th>\n",
" <td>1</td>\n",
" <td>139</td>\n",
" <td>46</td>\n",
" <td>19</td>\n",
" <td>83</td>\n",
" <td>28.7</td>\n",
" <td>0.654</td>\n",
" <td>22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>0</td>\n",
" <td>161</td>\n",
" <td>50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.9</td>\n",
" <td>0.254</td>\n",
" <td>65</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231</th>\n",
" <td>6</td>\n",
" <td>134</td>\n",
" <td>80</td>\n",
" <td>37</td>\n",
" <td>370</td>\n",
" <td>46.2</td>\n",
" <td>0.238</td>\n",
" <td>46</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>71</th>\n",
" <td>5</td>\n",
" <td>139</td>\n",
" <td>64</td>\n",
" <td>35</td>\n",
" <td>140</td>\n",
" <td>28.6</td>\n",
" <td>0.411</td>\n",
" <td>26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>106</th>\n",
" <td>1</td>\n",
" <td>96</td>\n",
" <td>122</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>22.4</td>\n",
" <td>0.207</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>270</th>\n",
" <td>10</td>\n",
" <td>101</td>\n",
" <td>86</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" <td>45.6</td>\n",
" <td>1.136</td>\n",
" <td>38</td>\n",
" </tr>\n",
" <tr>\n",
" <th>435</th>\n",
" <td>0</td>\n",
" <td>141</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>42.4</td>\n",
" <td>0.205</td>\n",
" <td>29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>102</th>\n",
" <td>0</td>\n",
" <td>125</td>\n",
" <td>96</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>22.5</td>\n",
" <td>0.262</td>\n",
" <td>21</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"60 2 84 0 0 0 0.0 \n",
"618 9 112 82 24 0 28.2 \n",
"346 1 139 46 19 83 28.7 \n",
"294 0 161 50 0 0 21.9 \n",
"231 6 134 80 37 370 46.2 \n",
".. ... ... ... ... ... ... \n",
"71 5 139 64 35 140 28.6 \n",
"106 1 96 122 0 0 22.4 \n",
"270 10 101 86 37 0 45.6 \n",
"435 0 141 0 0 0 42.4 \n",
"102 0 125 96 0 0 22.5 \n",
"\n",
" DiabetesPedigreeFunction Age \n",
"60 0.304 21 \n",
"618 1.282 50 \n",
"346 0.654 22 \n",
"294 0.254 65 \n",
"231 0.238 46 \n",
".. ... ... \n",
"71 0.411 26 \n",
"106 0.207 27 \n",
"270 1.136 38 \n",
"435 0.205 29 \n",
"102 0.262 21 \n",
"\n",
"[614 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>60</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>618</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>346</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>71</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>106</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>270</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>435</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>102</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"60 0\n",
"618 1\n",
"346 0\n",
"294 0\n",
"231 1\n",
".. ...\n",
"71 0\n",
"106 0\n",
"270 1\n",
"435 1\n",
"102 0\n",
"\n",
"[614 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>668</th>\n",
" <td>6</td>\n",
" <td>98</td>\n",
" <td>58</td>\n",
" <td>33</td>\n",
" <td>190</td>\n",
" <td>34.0</td>\n",
" <td>0.430</td>\n",
" <td>43</td>\n",
" </tr>\n",
" <tr>\n",
" <th>324</th>\n",
" <td>2</td>\n",
" <td>112</td>\n",
" <td>75</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.7</td>\n",
" <td>0.148</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>624</th>\n",
" <td>2</td>\n",
" <td>108</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.8</td>\n",
" <td>0.158</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>690</th>\n",
" <td>8</td>\n",
" <td>107</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.6</td>\n",
" <td>0.856</td>\n",
" <td>34</td>\n",
" </tr>\n",
" <tr>\n",
" <th>473</th>\n",
" <td>7</td>\n",
" <td>136</td>\n",
" <td>90</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29.9</td>\n",
" <td>0.210</td>\n",
" <td>50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>355</th>\n",
" <td>9</td>\n",
" <td>165</td>\n",
" <td>88</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.4</td>\n",
" <td>0.302</td>\n",
" <td>49</td>\n",
" </tr>\n",
" <tr>\n",
" <th>534</th>\n",
" <td>1</td>\n",
" <td>77</td>\n",
" <td>56</td>\n",
" <td>30</td>\n",
" <td>56</td>\n",
" <td>33.3</td>\n",
" <td>1.251</td>\n",
" <td>24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>344</th>\n",
" <td>8</td>\n",
" <td>95</td>\n",
" <td>72</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>36.8</td>\n",
" <td>0.485</td>\n",
" <td>57</td>\n",
" </tr>\n",
" <tr>\n",
" <th>296</th>\n",
" <td>2</td>\n",
" <td>146</td>\n",
" <td>70</td>\n",
" <td>38</td>\n",
" <td>360</td>\n",
" <td>28.0</td>\n",
" <td>0.337</td>\n",
" <td>29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>462</th>\n",
" <td>8</td>\n",
" <td>74</td>\n",
" <td>70</td>\n",
" <td>40</td>\n",
" <td>49</td>\n",
" <td>35.3</td>\n",
" <td>0.705</td>\n",
" <td>39</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"668 6 98 58 33 190 34.0 \n",
"324 2 112 75 32 0 35.7 \n",
"624 2 108 64 0 0 30.8 \n",
"690 8 107 80 0 0 24.6 \n",
"473 7 136 90 0 0 29.9 \n",
".. ... ... ... ... ... ... \n",
"355 9 165 88 0 0 30.4 \n",
"534 1 77 56 30 56 33.3 \n",
"344 8 95 72 0 0 36.8 \n",
"296 2 146 70 38 360 28.0 \n",
"462 8 74 70 40 49 35.3 \n",
"\n",
" DiabetesPedigreeFunction Age \n",
"668 0.430 43 \n",
"324 0.148 21 \n",
"624 0.158 21 \n",
"690 0.856 34 \n",
"473 0.210 50 \n",
".. ... ... \n",
"355 0.302 49 \n",
"534 1.251 24 \n",
"344 0.485 57 \n",
"296 0.337 29 \n",
"462 0.705 39 \n",
"\n",
"[154 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>668</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>324</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>624</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>690</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>473</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>355</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>534</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>344</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>296</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>462</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"668 0\n",
"324 0\n",
"624 0\n",
"690 0\n",
"473 0\n",
".. ...\n",
"355 1\n",
"534 0\n",
"344 0\n",
"296 1\n",
"462 0\n",
"\n",
"[154 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Optional, Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_into_train_test(\n",
"    df_input: DataFrame,\n",
"    target_colname: str = \"Outcome\",\n",
"    frac_train: float = 0.8,\n",
"    random_state: Optional[int] = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"Split a frame into train/test feature and target frames.\n",
"\n",
"    Args:\n",
"        df_input: Source frame holding the features and the target column.\n",
"        target_colname: Name of the target column.\n",
"        frac_train: Fraction of rows assigned to the training set; must be\n",
"            strictly between 0 and 1.\n",
"        random_state: Seed forwarded to ``train_test_split`` for a\n",
"            reproducible shuffle.\n",
"\n",
"    Returns:\n",
"        ``(X_train, X_test, y_train, y_test)``. The ``y_*`` parts keep the\n",
"        target as a one-column DataFrame.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``frac_train`` is outside (0, 1) or the target column\n",
"            is missing from ``df_input``.\n",
"    \"\"\"\n",
"    if not (0 < frac_train < 1):\n",
"        raise ValueError(\"Fraction must be between 0 and 1.\")\n",
"\n",
"    # Make sure the target column exists\n",
"    if target_colname not in df_input.columns:\n",
"        raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
"\n",
"    # Split into features and target (double brackets keep y as a DataFrame)\n",
"    X = df_input.drop(columns=[target_colname])\n",
"    y = df_input[[target_colname]]\n",
"\n",
"    # Pass train_size directly rather than test_size=1.0 - frac_train: the\n",
"    # subtraction is inexact in floating point (1.0 - 0.7 == 0.30000000000000004)\n",
"    # and sklearn rounds a fractional test_size up, which can move an extra\n",
"    # row into the test set.\n",
"    X_train, X_test, y_train, y_test = train_test_split(\n",
"        X, y,\n",
"        train_size=frac_train,\n",
"        random_state=random_state,\n",
"    )\n",
"\n",
"    return X_train, X_test, y_train, y_test\n",
"\n",
"# Split the data into training and test sets\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
"    df,\n",
"    target_colname=\"Outcome\",\n",
"    frac_train=0.8,\n",
"    random_state=42,  # fixed seed so the split is reproducible\n",
")\n",
"\n",
"# Show the resulting parts\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 306,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"random_state = 9\n",
"\n",
"# Candidate regressors. KNN (distance-based) and the MLP (gradient-trained,\n",
"# tanh units) are sensitive to feature scale, and the raw features differ by\n",
"# orders of magnitude (e.g. Insulin vs DiabetesPedigreeFunction), so both are\n",
"# wrapped in a pipeline with StandardScaler. Linear and tree-based models are\n",
"# left unchanged.\n",
"models = {\n",
"    \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
"    \"linear_poly\": {\n",
"        \"model\": make_pipeline(\n",
"            PolynomialFeatures(degree=2),\n",
"            linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
"        )\n",
"    },\n",
"    \"linear_interact\": {\n",
"        \"model\": make_pipeline(\n",
"            PolynomialFeatures(interaction_only=True),\n",
"            linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
"        )\n",
"    },\n",
"    \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
"    \"decision_tree\": {\n",
"        \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
"    },\n",
"    \"knn\": {\n",
"        \"model\": make_pipeline(\n",
"            StandardScaler(),\n",
"            neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1),\n",
"        )\n",
"    },\n",
"    \"random_forest\": {\n",
"        \"model\": ensemble.RandomForestRegressor(\n",
"            max_depth=7, random_state=random_state, n_jobs=-1\n",
"        )\n",
"    },\n",
"    \"mlp\": {\n",
"        \"model\": make_pipeline(\n",
"            StandardScaler(),\n",
"            neural_network.MLPRegressor(\n",
"                activation=\"tanh\",\n",
"                hidden_layer_sizes=(3,),\n",
"                max_iter=500,\n",
"                early_stopping=True,\n",
"                random_state=random_state,\n",
"            ),\n",
"        )\n",
"    },\n",
"}\n"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 307,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-09 10:46:39 +04:00
"Model: linear\n",
2024-11-08 22:14:23 +04:00
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import math\n",
"from pandas import DataFrame\n",
"from sklearn import metrics\n",
"\n",
"# Fit every candidate on the training split, then store its predictions and\n",
"# error metrics next to the estimator in its `models` entry.\n",
"for model_name, model_entry in models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"\n",
"    fitted_model = model_entry[\"model\"].fit(X_train.values, y_train.values.ravel())\n",
"    y_train_pred = fitted_model.predict(X_train.values)\n",
"    y_test_pred = fitted_model.predict(X_test.values)\n",
"\n",
"    model_entry.update(\n",
"        fitted=fitted_model,\n",
"        train_preds=y_train_pred,\n",
"        preds=y_test_pred,\n",
"        RMSE_train=math.sqrt(metrics.mean_squared_error(y_train, y_train_pred)),\n",
"        RMSE_test=math.sqrt(metrics.mean_squared_error(y_test, y_test_pred)),\n",
"        # \"RMAE\" here is the square root of the mean absolute error\n",
"        RMAE_test=math.sqrt(metrics.mean_absolute_error(y_test, y_test_pred)),\n",
"        R2_test=metrics.r2_score(y_test, y_test_pred),\n",
"    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод результатов оценки"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 308,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row0_col0, #T_222a8_row0_col1 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row0_col2 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #7e03a8;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row0_col3, #T_222a8_row7_col2 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row1_col0, #T_222a8_row2_col0 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #25ab82;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row1_col1 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #24868e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row1_col2, #T_222a8_row2_col2 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #a11b9b;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row1_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #d5546e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row2_col1 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #24878e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row2_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #d5536f;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row3_col0, #T_222a8_row6_col0 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #20a486;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row3_col1 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #228d8d;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row3_col2 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #9a169f;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row3_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #cf4c74;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row4_col0 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #21a685;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row4_col1 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #21918c;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row4_col2 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #a51f99;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row4_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #cc4977;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row5_col0 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #25838e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row5_col1 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #1f9f88;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row5_col2, #T_222a8_row7_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row5_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #bf3984;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row6_col1 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #1fa287;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row6_col2 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #a31e9a;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row6_col3 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #bc3587;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-09 10:46:39 +04:00
"#T_222a8_row7_col0, #T_222a8_row7_col1 {\n",
2024-11-08 22:14:23 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
2024-11-09 10:46:39 +04:00
"<table id=\"T_222a8\">\n",
2024-11-08 22:14:23 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_222a8_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_222a8_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_222a8_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_222a8_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_222a8_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_222a8_row0_col0\" class=\"data row0 col0\" >0.240052</td>\n",
" <td id=\"T_222a8_row0_col1\" class=\"data row0 col1\" >0.405871</td>\n",
" <td id=\"T_222a8_row0_col2\" class=\"data row0 col2\" >0.559210</td>\n",
" <td id=\"T_222a8_row0_col3\" class=\"data row0 col3\" >0.282505</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_222a8_level0_row1\" class=\"row_heading level0 row1\" >linear</th>\n",
" <td id=\"T_222a8_row1_col0\" class=\"data row1 col0\" >0.396793</td>\n",
" <td id=\"T_222a8_row1_col1\" class=\"data row1 col1\" >0.413576</td>\n",
" <td id=\"T_222a8_row1_col2\" class=\"data row1 col2\" >0.590024</td>\n",
" <td id=\"T_222a8_row1_col3\" class=\"data row1 col3\" >0.255003</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_222a8_level0_row2\" class=\"row_heading level0 row2\" >ridge</th>\n",
" <td id=\"T_222a8_row2_col0\" class=\"data row2 col0\" >0.396822</td>\n",
" <td id=\"T_222a8_row2_col1\" class=\"data row2 col1\" >0.414236</td>\n",
" <td id=\"T_222a8_row2_col2\" class=\"data row2 col2\" >0.590431</td>\n",
" <td id=\"T_222a8_row2_col3\" class=\"data row2 col3\" >0.252623</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_222a8_level0_row3\" class=\"row_heading level0 row3\" >linear_poly</th>\n",
" <td id=\"T_222a8_row3_col0\" class=\"data row3 col0\" >0.370076</td>\n",
" <td id=\"T_222a8_row3_col1\" class=\"data row3 col1\" >0.422852</td>\n",
" <td id=\"T_222a8_row3_col2\" class=\"data row3 col2\" >0.584147</td>\n",
" <td id=\"T_222a8_row3_col3\" class=\"data row3 col3\" >0.221209</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_222a8_level0_row4\" class=\"row_heading level0 row4\" >linear_interact</th>\n",
" <td id=\"T_222a8_row4_col0\" class=\"data row4 col0\" >0.380128</td>\n",
" <td id=\"T_222a8_row4_col1\" class=\"data row4 col1\" >0.426815</td>\n",
" <td id=\"T_222a8_row4_col2\" class=\"data row4 col2\" >0.593532</td>\n",
" <td id=\"T_222a8_row4_col3\" class=\"data row4 col3\" >0.206543</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_222a8_level0_row5\" class=\"row_heading level0 row5\" >decision_tree</th>\n",
" <td id=\"T_222a8_row5_col0\" class=\"data row5 col0\" >0.249880</td>\n",
" <td id=\"T_222a8_row5_col1\" class=\"data row5 col1\" >0.445708</td>\n",
" <td id=\"T_222a8_row5_col2\" class=\"data row5 col2\" >0.520376</td>\n",
" <td id=\"T_222a8_row5_col3\" class=\"data row5 col3\" >0.134743</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_222a8_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_222a8_row6_col0\" class=\"data row6 col0\" >0.373319</td>\n",
" <td id=\"T_222a8_row6_col1\" class=\"data row6 col1\" >0.450285</td>\n",
" <td id=\"T_222a8_row6_col2\" class=\"data row6 col2\" >0.592157</td>\n",
" <td id=\"T_222a8_row6_col3\" class=\"data row6 col3\" >0.116883</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" <tr>\n",
2024-11-09 10:46:39 +04:00
" <th id=\"T_222a8_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_222a8_row7_col0\" class=\"data row7 col0\" >0.623529</td>\n",
" <td id=\"T_222a8_row7_col1\" class=\"data row7 col1\" >0.544323</td>\n",
" <td id=\"T_222a8_row7_col2\" class=\"data row7 col2\" >0.658689</td>\n",
" <td id=\"T_222a8_row7_col3\" class=\"data row7 col3\" >-0.290498</td>\n",
2024-11-08 22:14:23 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-11-09 10:46:39 +04:00
"<pandas.io.formats.style.Styler at 0x203c0120950>"
2024-11-08 22:14:23 +04:00
]
},
2024-11-09 10:46:39 +04:00
"execution_count": 308,
2024-11-08 22:14:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"metric_columns = [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[metric_columns]\n",
"\n",
"# Rank models by test RMSE and colour-code the metric columns\n",
"(\n",
"    reg_metrics.sort_values(by=\"RMSE_test\")\n",
"    .style.background_gradient(\n",
"        cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"]\n",
"    )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок\n",
"\n",
"Получение лучшей модели (с наименьшим RMSE на тестовой выборке)\n"
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 309,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'random_forest'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Name of the model with the lowest test RMSE\n",
"best_model = str(reg_metrics[\"RMSE_test\"].idxmin())\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 310,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 36 candidates, totalling 180 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\5_semester\\AIM\\rep\\AIM-PIbd-31-Razubaev-S-M\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 200}\n",
"Лучший результат (MSE): 0.15427721639903466\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.ensemble import RandomForestRegressor  # regressor, as in the model comparison above\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Candidate predictor subset and target. Note: the grid search below is fit on\n",
"# the full X_train / y_train split made earlier, not on these. dropna() is not\n",
"# applied in place, so re-running this cell does not mutate the shared `df`.\n",
"df_no_na = df.dropna()\n",
"X = df_no_na[[\"Glucose\", \"Age\", \"BloodPressure\", \"DiabetesPedigreeFunction\"]]\n",
"y = df_no_na[\"Outcome\"]\n",
"\n",
"# Fixed seed (same `random_state` as the model comparison) so the search is\n",
"# reproducible; without it, repeated runs can pick different \"best\" parameters.\n",
"model = RandomForestRegressor(random_state=random_state)\n",
"\n",
"param_grid = {\n",
"    'n_estimators': [50, 100, 200],\n",
"    'max_depth': [None, 10, 20, 30],\n",
"    'min_samples_split': [2, 5, 10]\n",
"}\n",
"\n",
"# Hyperparameter search with 5-fold cross-validation\n",
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
"                           scoring='neg_mean_squared_error', cv=5, n_jobs=-1, verbose=2)\n",
"\n",
"# Fit on the training split; ravel() passes a 1-D target and avoids the\n",
"# DataConversionWarning raised for the (n, 1) y_train DataFrame\n",
"grid_search.fit(X_train, y_train.values.ravel())\n",
"\n",
"# Best parameters and their cross-validated MSE (the scorer returns negated MSE)\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
2024-11-08 22:14:23 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-11-09 10:46:39 +04:00
"Обучение модели с новыми гиперпараметрами и сравнение результатов для старых и новых гиперпараметров"
2024-11-08 22:14:23 +04:00
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 319,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
2024-11-09 10:46:39 +04:00
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 36 candidates, totalling 180 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\5_semester\\AIM\\rep\\AIM-PIbd-31-Razubaev-S-M\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\5_semester\\AIM\\rep\\AIM-PIbd-31-Razubaev-S-M\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\5_semester\\AIM\\rep\\AIM-PIbd-31-Razubaev-S-M\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\5_semester\\AIM\\rep\\AIM-PIbd-31-Razubaev-S-M\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\5_semester\\AIM\\rep\\AIM-PIbd-31-Razubaev-S-M\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\5_semester\\AIM\\rep\\AIM-PIbd-31-Razubaev-S-M\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Старые параметры: {'max_depth': 30, 'min_samples_split': 10, 'n_estimators': 50}\n",
"Лучший результат (MSE) на старых параметрах: 0.1543002886456971\n",
"\n",
"Новые параметры: {'max_depth': 20, 'min_samples_split': 10, 'n_estimators': 200}\n",
"Лучший результат (MSE) на новых параметрах: 0.15791709286040012\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 0.16712438177283198\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 0.408808490338486\n"
]
2024-11-08 22:14:23 +04:00
}
],
"source": [
2024-11-09 10:46:39 +04:00
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# RandomForestRegressor expects a 1-D target; the (n, 1) y_train DataFrame\n",
"# triggers a DataConversionWarning on every fit.\n",
"y_train_1d = y_train.values.ravel()\n",
"\n",
"old_param_grid = {\n",
"    'n_estimators': [50, 100, 200],  # number of trees\n",
"    'max_depth': [None, 10, 20, 30],  # maximum tree depth\n",
"    'min_samples_split': [2, 5, 10]  # minimum number of samples to split a node\n",
"}\n",
"\n",
"# Fixed random_state: without it, re-running the same grid can select\n",
"# different \"best\" parameters on every run.\n",
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(random_state=random_state),\n",
"                               param_grid=old_param_grid,\n",
"                               scoring='neg_mean_squared_error', cv=5, n_jobs=-1, verbose=2)\n",
"\n",
"old_grid_search.fit(X_train, y_train_1d)\n",
"\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_  # the scorer returns negated MSE, flip the sign\n",
"\n",
"new_param_grid = {\n",
"    'n_estimators': [200],\n",
"    'max_depth': [20],\n",
"    'min_samples_split': [10]\n",
"}\n",
"\n",
"# Same cv=5 as the old search so the two cross-validated MSEs are comparable\n",
"# (with cv=2 each fold trains on only half of X_train).\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(random_state=random_state),\n",
"                               param_grid=new_param_grid,\n",
"                               scoring='neg_mean_squared_error', cv=5)\n",
"\n",
"new_grid_search.fit(X_train, y_train_1d)\n",
"\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_  # the scorer returns negated MSE, flip the sign\n",
"\n",
"model_best = RandomForestRegressor(**new_best_params, random_state=random_state)\n",
"model_best.fit(X_train, y_train_1d)\n",
"\n",
"model_oldbest = RandomForestRegressor(**old_best_params, random_state=random_state)\n",
"model_oldbest.fit(X_train, y_train_1d)\n",
"\n",
"y_pred = model_best.predict(X_test)\n",
"y_oldpred = model_oldbest.predict(X_test)\n",
"\n",
"# Evaluate both models on the held-out test set so they can actually be compared\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"old_mse = metrics.mean_squared_error(y_test, y_oldpred)\n",
"old_rmse = np.sqrt(old_mse)\n",
"\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных (старые параметры):\", old_mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных (старые параметры):\", old_rmse)\n",
"print(\"\\nНовые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
2024-11-08 22:14:23 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-11-09 10:46:39 +04:00
"Визуализация"
2024-11-08 22:14:23 +04:00
]
},
{
"cell_type": "code",
2024-11-09 10:46:39 +04:00
"execution_count": 329,
2024-11-08 22:14:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
2024-11-09 10:46:39 +04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAAHWCAYAAABquigpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOx9d5jcxP3+q7K718/n3nHvYLDp1aEZm5aEb4CQhJJAKCGh/AIJhJ7QWwgBEqqpAQIEAhxgYzAlGDDFYIyNO+6+Yl/bqvb7Q7vSSBrtStvvbt7n8eOVbpqk0Wg+877z+XCapmlgYGBgYGBgYGBgYGBgyBl8qRvAwMDAwMDAwMDAwMDQU8AMLAYGBgYGBgYGBgYGhjyBGVgMDAwMDAwMDAwMDAx5AjOwGBgYGBgYGBgYGBgY8gRmYDEwMDAwMDAwMDAwMOQJzMBiYGBgYGBgYGBgYGDIE5iBxcDAwMDAwMDAwMDAkCcwA4uBgYGBgYGBgYGBgSFPYAYWAwMDAwMDAwMDAwNDnsAMLAYGBgYGBgYGBl947bXXsHTpUuP45ZdfxvLly0vXIAaGMgIzsBgYuhnWrl2Lc889F2PGjEFFRQXq6upw0EEH4Z577kE0Gi118xgYGBgYegGWLVuGiy66CKtXr8bHH3+M8847D52dnaVuFgNDWYDTNE0rdSMYGBi84fXXX8dPfvIThEIhnH766Zg2bRoSiQQ+/PBDvPjiizjzzDPx4IMPlrqZDAwMDAw9HM3NzTjwwAOxZs0aAMCPf/xjvPjiiyVuFQNDeYAZWAwM3QTr16/HHnvsgeHDh+Odd97BkCFDLH9fs2YNXn/9dVx00UUlaiEDAwMDQ29CPB7HN998g6qqKkyePLnUzWFgKBswiSADQzfBbbfdhq6uLjzyyCMO4woAxo0bZzGuOI7DhRdeiKeffhoTJ05ERUUFZs6ciffff9+S7/vvv8cFF1yAiRMnorKyEv369cNPfvITbNiwwZJu3rx54DjO+FdVVYXdd98dDz/8sCXdmWeeiZqaGkf7XnjhBXAch0WLFlnOf/LJJzjmmGNQX1+PqqoqHHbYYfjf//5nSXPdddeB4zi0tLRYzn/22WfgOA7z5s2z1D9q1ChLuk2bNqGyshIcxzmu64033sAhhxyC6upq1NbW4thjj/W0j8B+P+z/rrvuOkf7V65ciZNPPhl1dXXo168fLrroIsRiMUfZTz31FGbOnInKykr07dsXp556KjZt2kRth1v99vsci8Vw3XXXYcKECaioqMCQIUPw4x//GGvXrgUAbNiwwXEvOzs7MXPmTIwePRrbtm0zzt9xxx048MAD0a9fP1RWVmLmzJl44YUXLPW1trZizpw5GD58OEKhEIYMGYKf/exn+P777y3pvJSVus4LL7zQcf64446zPO/Uddxxxx2OtNOmTcOsWbOM40WLFoHjOGp9Kdj707XXXgue57Fw4UJLul//+tcIBoP46quvXMtKXQfZNwDg9ttvB8dxlrblI3+6Z566T+n+nXnmmQDMvk6+O6qqYo899qC+f17f/1mzZmHatGmOtHfccYejvlGjRuG4445zvS+pZ5kqf8WKFaisrMTpp59uSffhhx9CEAT84Q9/cC0L0N/ZKVOmoKamBnV1ddh///3x8ssvW9L4af8rr7yCY489FkOHDkUoFMLYsWPx5z//GYqiWPLSni/t/gPexi6/z8Peh5YsWWL0B1o7Q6EQZs6cicmTJ/vqxwwMPR1iqRvAwMDgDa+++irGjBmDAw880HOe9957D8899xx+97vfIRQK4f7778cxxxyDTz/91JgYLFmyBB999BFOPfVUDB8+HBs2bMADDzyAWbNm4dtvv0VVVZWlzLvvvhv9+/dHR0cHHn30UZxzzjkYNWoUjjzySN/X9M4772DOnDmYOXOmMXF97LHHcPjhh+ODDz7Avvvu67tMGq655hqqIfPkk0/ijDPOwOzZs3HrrbciEonggQ
cewMEHH4wvv/zSYajRcMMNN2D06NHGcVdXF84//3xq2pNPPhmjRo3CzTffjI8//hh/+9vfsGvXLjzxxBNGmhtvvBFXX301Tj75ZJx99tlobm7Gvffei0MPPRRffvkl+vTp4yj3qKOOMiaSS5Yswd/+9jfL3xVFwXHHHYeFCxfi1FNPxUUXXYTOzk4sWLAA33zzDcaOHesoU5IknHTSSdi4cSP+97//WYz6e+65ByeccAJ+9rOfIZFI4Nlnn8VPfvITvPbaazj22GMBAIlEArW1tbjooovQr18/rF27Fvfeey++/vprLFu2zFdZ5YSrrroKr776Kn71q19h2bJlqK2txVtvvYWHHnoIf/7znzF9+nRf5bW1teHmm2/Ouj1u+TM98yOPPBJPPvmkkf6ll17Cf/7zH8s5Wr9I4cknn7Q8x3LD5MmT8ec//xmXXXYZ/u///g8nnHACwuEwzjzzTEyaNAk33HBD2vzhcBg/+tGPMGrUKESjUcybNw8nnXQSFi9enNW4NG/ePNTU1ODSSy9FTU0N3nnnHVxzzTXo6OjA7bff7ru8fIxdXpDJEE0h137MwNDjoDEwMJQ92tvbNQDaiSee6DkPAA2A9tlnnxnnvv/+e62iokL70Y9+ZJyLRCKOvIsXL9YAaE888YRx7rHHHtMAaOvXrzfOrVq1SgOg3Xbbbca5M844Q6uurnaU+e9//1sDoL377ruapmmaqqra+PHjtdmzZ2uqqlraM3r0aO2oo44yzl177bUaAK25udlS5pIlSzQA2mOPPWapf7fddjOOv/nmG43neW3OnDmW9nd2dmp9+vTRzjnnHEuZ27dv1+rr6x3n7UjdjyVLlljONzc3awC0a6+91tH+E044wZL2ggsu0ABoX331laZpmrZhwwZNEATtxhtvtKRbtmyZJoqi43wikdAAaBdeeKFxzn6fNU3THn30UQ2AdtdddzmuI3Xv169fb9xLVVW1n/3sZ1pVVZX2ySefOPLY+0wikdCmTZumHX744Y60JG677TYNgNbS0uK7LADab37zG0eZxx57rOV5p67j9ttvd6SdOnWqdthhhxnH7777rgZA+/e//+3aZnt/0jT9eQSDQe3ss8/Wdu3apQ0bNkzbe++9NUmSXMshr4PsG5dffrk2cOBAbebMmZa25ZrfyzMnkeqjNNjf/Vgspo0cOdJ4p+zvn5f3X9M07bDDDtOmTp3qSHv77bc7xprddttNO/bYY6nt0zTzWZLlK4qiHXzwwdqgQYO0lpYW7Te/+Y0miqLjnfWCpqYmDYB2xx13ZNV+2jh77rnnalVVVVosFjPOcRynXXPNNZZ09vvvZ+zy+zzIPtTY2KgB0I455hhH38i1HzMw9HQwiSADQzdAR0cHAKC2ttZXvgMOOAAzZ840jkeOHIkTTzwRb731liFNqaysNP4uSRJaW1sxbtw49OnTB1988YWjzF27dqGlpQXr1q3D3XffDUEQcNhhhznStbS0WP7ZvUstXboUq1evxmmnnYbW1lYjXTgcxhFHHIH3338fqqpa8uzcudNSZnt7e8Z7cMUVV2DGjBn4yU9+Yjm/YMECtLW14ac//amlTEEQsN9+++Hdd9/NWLZf/OY3v7Ec//a3vwUANDY2AtBZBFVVcfLJJ1vaNHjwYIwfP97RphQrV1FRkbbeF198Ef379zfqI2GX/gDAZZddhqeffhrPP/88dbWe7DO7du1Ce3s7DjnkEGp/6ezsRFNTExYvXox//etfmDp1Kvr27ZtVWbFYzNGvJEmiXnMkEnGktcuxyDa2tLSgra2N+nc7pk2bhuuvvx4PP/wwZs+ejZaWFjz++OMQRX+ikC1btuDee+/F1VdfTZVx5ZLf7zP3g/vuuw+tra249tprXdNkev9TUBTFkTYSiVDTSpKElpYWtLa2QpbljO3keR7z5s1DV1cX5syZg/vvvx9XXHEF9t57b0/Xmapv7dq1uOWWW8DzPA466KCs2k/281R/O+SQQxCJRL
By5UrjbwMHDsTmzZvTtiubscvr80hB0zRcccUVOOmkk7DffvulTZtrP2Zg6IlgEkEGhm6Auro6APDtAnf8+PGOcxM
2024-11-08 22:14:23 +04:00
"text/plain": [
2024-11-09 10:46:39 +04:00
"<Figure size 1000x500 with 1 Axes>"
2024-11-08 22:14:23 +04:00
]
},
"metadata": {},
2024-11-09 10:46:39 +04:00
"output_type": "display_data"
2024-11-08 22:14:23 +04:00
}
],
"source": [
2024-11-09 10:46:39 +04:00
"plt.figure(figsize=(10, 5))\n",
"plt.plot(y_test.values, label='Истинные значения', color='blue', linewidth=2)\n",
"plt.plot(y_oldpred, label='Предсказанные значения(после)', color='red', linestyle='--', linewidth=2)\n",
"plt.plot(y_pred, label='Предсказанные значения(до)', color='green', linestyle='-', linewidth=2)\n",
"\n",
"plt.title('Сравнение предсказанных и истинных значений')\n",
"plt.xlabel('Подбор параметров')\n",
"plt.ylabel('Значения')\n",
"plt.grid()\n",
"plt.legend( loc ='lower right')\n",
"plt.show()"
2024-11-08 22:14:23 +04:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}