3832 lines
377 KiB
Plaintext
Raw Permalink Normal View History

2024-11-27 18:20:28 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подключили датафрейм и выгрузили данные"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Pregnancies', 'Glucose', 'BloodPressure', 'SkinThickness', 'Insulin',\n",
" 'BMI', 'DiabetesPedigreeFunction', 'Age', 'Outcome'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6</td>\n",
" <td>148</td>\n",
" <td>72</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" <td>33.6</td>\n",
" <td>0.627</td>\n",
" <td>50</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>85</td>\n",
" <td>66</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>26.6</td>\n",
" <td>0.351</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>8</td>\n",
" <td>183</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>23.3</td>\n",
" <td>0.672</td>\n",
" <td>32</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>89</td>\n",
" <td>66</td>\n",
" <td>23</td>\n",
" <td>94</td>\n",
" <td>28.1</td>\n",
" <td>0.167</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>137</td>\n",
" <td>40</td>\n",
" <td>35</td>\n",
" <td>168</td>\n",
" <td>43.1</td>\n",
" <td>2.288</td>\n",
" <td>33</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>763</th>\n",
" <td>10</td>\n",
" <td>101</td>\n",
" <td>76</td>\n",
" <td>48</td>\n",
" <td>180</td>\n",
" <td>32.9</td>\n",
" <td>0.171</td>\n",
" <td>63</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>764</th>\n",
" <td>2</td>\n",
" <td>122</td>\n",
" <td>70</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" <td>36.8</td>\n",
" <td>0.340</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>765</th>\n",
" <td>5</td>\n",
" <td>121</td>\n",
" <td>72</td>\n",
" <td>23</td>\n",
" <td>112</td>\n",
" <td>26.2</td>\n",
" <td>0.245</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>766</th>\n",
" <td>1</td>\n",
" <td>126</td>\n",
" <td>60</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.1</td>\n",
" <td>0.349</td>\n",
" <td>47</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>767</th>\n",
" <td>1</td>\n",
" <td>93</td>\n",
" <td>70</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>30.4</td>\n",
" <td>0.315</td>\n",
" <td>23</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>768 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 148 72 35 0 33.6 \n",
"1 1 85 66 29 0 26.6 \n",
"2 8 183 64 0 0 23.3 \n",
"3 1 89 66 23 94 28.1 \n",
"4 0 137 40 35 168 43.1 \n",
".. ... ... ... ... ... ... \n",
"763 10 101 76 48 180 32.9 \n",
"764 2 122 70 27 0 36.8 \n",
"765 5 121 72 23 112 26.2 \n",
"766 1 126 60 0 0 30.1 \n",
"767 1 93 70 31 0 30.4 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"0 0.627 50 1 \n",
"1 0.351 31 0 \n",
"2 0.672 32 1 \n",
"3 0.167 21 0 \n",
"4 2.288 33 1 \n",
".. ... ... ... \n",
"763 0.171 63 0 \n",
"764 0.340 27 0 \n",
"765 0.245 30 0 \n",
"766 0.349 47 1 \n",
"767 0.315 23 0 \n",
"\n",
"[768 rows x 9 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"set_config(transform_output=\"pandas\")\n",
"random_state = 42\n",
"df = pd.read_csv(\"data/diabetes.csv\")\n",
"print(df.columns)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Бизнес-цели\n",
"1. Предсказание риска развития диабета. Будем классифицировать пациентов на основе медданных для того, чтобы определить у кого есть риск развития диабета(будем использовать целевой признак \"Outcome\"). Актуальность для раннего выявления диабета.\n",
"2. Анализ ключевых факторов, влияющих на диабет. Предсказание вероятности развития диабета на основе медданных. Актуальность для планирвоания лечения.\n",
"## Определение достижимого уровня качества модели для первой задачи\n",
"Разделение данных на обучающую и тестовые выборки 80/20 для задачи классификации"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>353</th>\n",
" <td>1</td>\n",
" <td>90</td>\n",
" <td>62</td>\n",
" <td>12</td>\n",
" <td>43</td>\n",
" <td>27.2</td>\n",
" <td>0.580</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>711</th>\n",
" <td>5</td>\n",
" <td>126</td>\n",
" <td>78</td>\n",
" <td>27</td>\n",
" <td>22</td>\n",
" <td>29.6</td>\n",
" <td>0.439</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>373</th>\n",
" <td>2</td>\n",
" <td>105</td>\n",
" <td>58</td>\n",
" <td>40</td>\n",
" <td>94</td>\n",
" <td>34.9</td>\n",
" <td>0.225</td>\n",
" <td>25</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46</th>\n",
" <td>1</td>\n",
" <td>146</td>\n",
" <td>56</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29.7</td>\n",
" <td>0.564</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>682</th>\n",
" <td>0</td>\n",
" <td>95</td>\n",
" <td>64</td>\n",
" <td>39</td>\n",
" <td>105</td>\n",
" <td>44.6</td>\n",
" <td>0.366</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>451</th>\n",
" <td>2</td>\n",
" <td>134</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>28.9</td>\n",
" <td>0.542</td>\n",
" <td>23</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>113</th>\n",
" <td>4</td>\n",
" <td>76</td>\n",
" <td>62</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>34.0</td>\n",
" <td>0.391</td>\n",
" <td>25</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>556</th>\n",
" <td>1</td>\n",
" <td>97</td>\n",
" <td>70</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" <td>38.1</td>\n",
" <td>0.218</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>667</th>\n",
" <td>10</td>\n",
" <td>111</td>\n",
" <td>70</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" <td>27.5</td>\n",
" <td>0.141</td>\n",
" <td>40</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>107</th>\n",
" <td>4</td>\n",
" <td>144</td>\n",
" <td>58</td>\n",
" <td>28</td>\n",
" <td>140</td>\n",
" <td>29.5</td>\n",
" <td>0.287</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"353 1 90 62 12 43 27.2 \n",
"711 5 126 78 27 22 29.6 \n",
"373 2 105 58 40 94 34.9 \n",
"46 1 146 56 0 0 29.7 \n",
"682 0 95 64 39 105 44.6 \n",
".. ... ... ... ... ... ... \n",
"451 2 134 70 0 0 28.9 \n",
"113 4 76 62 0 0 34.0 \n",
"556 1 97 70 40 0 38.1 \n",
"667 10 111 70 27 0 27.5 \n",
"107 4 144 58 28 140 29.5 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"353 0.580 24 0 \n",
"711 0.439 40 0 \n",
"373 0.225 25 0 \n",
"46 0.564 29 0 \n",
"682 0.366 22 0 \n",
".. ... ... ... \n",
"451 0.542 23 1 \n",
"113 0.391 25 0 \n",
"556 0.218 30 0 \n",
"667 0.141 40 1 \n",
"107 0.287 37 0 \n",
"\n",
"[614 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>353</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>711</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>373</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>682</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>451</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>113</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>556</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>667</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>107</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"353 0\n",
"711 0\n",
"373 0\n",
"46 0\n",
"682 0\n",
".. ...\n",
"451 1\n",
"113 0\n",
"556 0\n",
"667 1\n",
"107 0\n",
"\n",
"[614 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>44</th>\n",
" <td>7</td>\n",
" <td>159</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.294</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>672</th>\n",
" <td>10</td>\n",
" <td>68</td>\n",
" <td>106</td>\n",
" <td>23</td>\n",
" <td>49</td>\n",
" <td>35.5</td>\n",
" <td>0.285</td>\n",
" <td>47</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>700</th>\n",
" <td>2</td>\n",
" <td>122</td>\n",
" <td>76</td>\n",
" <td>27</td>\n",
" <td>200</td>\n",
" <td>35.9</td>\n",
" <td>0.483</td>\n",
" <td>26</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>630</th>\n",
" <td>7</td>\n",
" <td>114</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.732</td>\n",
" <td>34</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>81</th>\n",
" <td>2</td>\n",
" <td>74</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.102</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>3</td>\n",
" <td>88</td>\n",
" <td>58</td>\n",
" <td>11</td>\n",
" <td>54</td>\n",
" <td>24.8</td>\n",
" <td>0.267</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>637</th>\n",
" <td>2</td>\n",
" <td>94</td>\n",
" <td>76</td>\n",
" <td>18</td>\n",
" <td>66</td>\n",
" <td>31.6</td>\n",
" <td>0.649</td>\n",
" <td>23</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>593</th>\n",
" <td>2</td>\n",
" <td>82</td>\n",
" <td>52</td>\n",
" <td>22</td>\n",
" <td>115</td>\n",
" <td>28.5</td>\n",
" <td>1.699</td>\n",
" <td>25</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>425</th>\n",
" <td>4</td>\n",
" <td>184</td>\n",
" <td>78</td>\n",
" <td>39</td>\n",
" <td>277</td>\n",
" <td>37.0</td>\n",
" <td>0.264</td>\n",
" <td>31</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>273</th>\n",
" <td>1</td>\n",
" <td>71</td>\n",
" <td>78</td>\n",
" <td>50</td>\n",
" <td>45</td>\n",
" <td>33.2</td>\n",
" <td>0.422</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"44 7 159 64 0 0 27.4 \n",
"672 10 68 106 23 49 35.5 \n",
"700 2 122 76 27 200 35.9 \n",
"630 7 114 64 0 0 27.4 \n",
"81 2 74 0 0 0 0.0 \n",
".. ... ... ... ... ... ... \n",
"32 3 88 58 11 54 24.8 \n",
"637 2 94 76 18 66 31.6 \n",
"593 2 82 52 22 115 28.5 \n",
"425 4 184 78 39 277 37.0 \n",
"273 1 71 78 50 45 33.2 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"44 0.294 40 0 \n",
"672 0.285 47 0 \n",
"700 0.483 26 0 \n",
"630 0.732 34 1 \n",
"81 0.102 22 0 \n",
".. ... ... ... \n",
"32 0.267 22 0 \n",
"637 0.649 23 0 \n",
"593 1.699 25 0 \n",
"425 0.264 31 1 \n",
"273 0.422 21 0 \n",
"\n",
"[154 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>44</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>672</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>700</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>630</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>81</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>637</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>593</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>425</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>273</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"44 0\n",
"672 0\n",
"700 0\n",
"630 1\n",
"81 0\n",
".. ...\n",
"32 0\n",
"637 0\n",
"593 0\n",
"425 1\n",
"273 0\n",
"\n",
"[154 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"random_state = 42\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ]\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"Outcome\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Формирование конвейера для классификации данных\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"features_engineering -- трансформер для конструирования признаков\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"features_postprocessing -- трансформер для унитарного кодирования новых признаков\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков\n",
"\n",
"Конвейер выполняется последовательно.\n",
"\n",
"Трансформер выполняет параллельно для указанного набора колонок."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Построение конвейеров предобработки\n",
"class DiabetesFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
"\n",
" def fit(self, X, y=None):\n",
" return self\n",
"\n",
" def transform(self, X, y=None):\n",
" # Добавим признак отношения индекса массы тела и возраста.\n",
" X = X.copy()\n",
" X[\"Glucose_Insulin\"] = X[\"Glucose\"] / X[\"Insulin\"]\n",
" return X\n",
"\n",
" def get_feature_names_out(self, features_in):\n",
" new_features = [\"Glucose_Insulin\"]\n",
" return np.append(features_in, new_features, axis=0)\n",
"\n",
"# Обработка числовых данных. Числовой конвейр: заполнение пропущенных значений медианой и стандартизация\n",
"preprocessing_num_class = Pipeline(steps=[\n",
" ('imputer', SimpleImputer(strategy='median')),\n",
" ('scaler', StandardScaler())\n",
"])\n",
"\n",
"preprocessing_cat_class = Pipeline(steps=[\n",
" ('imputer', SimpleImputer(strategy='most_frequent')),\n",
" ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop='first'))\n",
"])\n",
"\n",
"columns_to_drop = []\n",
"numeric_columns = [\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\",\n",
" \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]\n",
"cat_columns = [\"Outcome\"]\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_num\", preprocessing_num_class, numeric_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat_class, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" ('preprocessing_cat', preprocessing_cat_class, [\"Outcome\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"custom_features\", DiabetesFeatures()),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Работа конвейера:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome_1</th>\n",
" <th>Glucose_Insulin</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>353</th>\n",
" <td>-0.851355</td>\n",
" <td>-0.980131</td>\n",
" <td>-0.404784</td>\n",
" <td>-0.553973</td>\n",
" <td>-0.331319</td>\n",
" <td>-0.607678</td>\n",
" <td>0.310794</td>\n",
" <td>-0.792169</td>\n",
" <td>0.0</td>\n",
" <td>2.958266</td>\n",
" </tr>\n",
" <tr>\n",
" <th>711</th>\n",
" <td>0.356576</td>\n",
" <td>0.161444</td>\n",
" <td>0.465368</td>\n",
" <td>0.392787</td>\n",
" <td>-0.526398</td>\n",
" <td>-0.302139</td>\n",
" <td>-0.116439</td>\n",
" <td>0.561034</td>\n",
" <td>0.0</td>\n",
" <td>-0.306696</td>\n",
" </tr>\n",
" <tr>\n",
" <th>373</th>\n",
" <td>-0.549372</td>\n",
" <td>-0.504474</td>\n",
" <td>-0.622322</td>\n",
" <td>1.213312</td>\n",
" <td>0.142444</td>\n",
" <td>0.372594</td>\n",
" <td>-0.764862</td>\n",
" <td>-0.707594</td>\n",
" <td>0.0</td>\n",
" <td>-3.541575</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46</th>\n",
" <td>-0.851355</td>\n",
" <td>0.795653</td>\n",
" <td>-0.731091</td>\n",
" <td>-1.311380</td>\n",
" <td>-0.730766</td>\n",
" <td>-0.289408</td>\n",
" <td>0.262314</td>\n",
" <td>-0.369293</td>\n",
" <td>0.0</td>\n",
" <td>-1.088792</td>\n",
" </tr>\n",
" <tr>\n",
" <th>682</th>\n",
" <td>-1.153338</td>\n",
" <td>-0.821579</td>\n",
" <td>-0.296015</td>\n",
" <td>1.150195</td>\n",
" <td>0.244628</td>\n",
" <td>1.607482</td>\n",
" <td>-0.337630</td>\n",
" <td>-0.961320</td>\n",
" <td>0.0</td>\n",
" <td>-3.358486</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>451</th>\n",
" <td>-0.549372</td>\n",
" <td>0.415128</td>\n",
" <td>0.030292</td>\n",
" <td>-1.311380</td>\n",
" <td>-0.730766</td>\n",
" <td>-0.391255</td>\n",
" <td>0.195653</td>\n",
" <td>-0.876744</td>\n",
" <td>1.0</td>\n",
" <td>-0.568071</td>\n",
" </tr>\n",
" <tr>\n",
" <th>113</th>\n",
" <td>0.054593</td>\n",
" <td>-1.424076</td>\n",
" <td>-0.404784</td>\n",
" <td>-1.311380</td>\n",
" <td>-0.730766</td>\n",
" <td>0.258017</td>\n",
" <td>-0.261879</td>\n",
" <td>-0.707594</td>\n",
" <td>0.0</td>\n",
" <td>1.948744</td>\n",
" </tr>\n",
" <tr>\n",
" <th>556</th>\n",
" <td>-0.851355</td>\n",
" <td>-0.758158</td>\n",
" <td>0.030292</td>\n",
" <td>1.213312</td>\n",
" <td>-0.730766</td>\n",
" <td>0.779980</td>\n",
" <td>-0.786072</td>\n",
" <td>-0.284718</td>\n",
" <td>0.0</td>\n",
" <td>1.037483</td>\n",
" </tr>\n",
" <tr>\n",
" <th>667</th>\n",
" <td>1.866489</td>\n",
" <td>-0.314212</td>\n",
" <td>0.030292</td>\n",
" <td>0.392787</td>\n",
" <td>-0.730766</td>\n",
" <td>-0.569486</td>\n",
" <td>-1.019383</td>\n",
" <td>0.561034</td>\n",
" <td>1.0</td>\n",
" <td>0.429976</td>\n",
" </tr>\n",
" <tr>\n",
" <th>107</th>\n",
" <td>0.054593</td>\n",
" <td>0.732232</td>\n",
" <td>-0.622322</td>\n",
" <td>0.455904</td>\n",
" <td>0.569759</td>\n",
" <td>-0.314870</td>\n",
" <td>-0.577001</td>\n",
" <td>0.307308</td>\n",
" <td>0.0</td>\n",
" <td>1.285160</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"353 -0.851355 -0.980131 -0.404784 -0.553973 -0.331319 -0.607678 \n",
"711 0.356576 0.161444 0.465368 0.392787 -0.526398 -0.302139 \n",
"373 -0.549372 -0.504474 -0.622322 1.213312 0.142444 0.372594 \n",
"46 -0.851355 0.795653 -0.731091 -1.311380 -0.730766 -0.289408 \n",
"682 -1.153338 -0.821579 -0.296015 1.150195 0.244628 1.607482 \n",
".. ... ... ... ... ... ... \n",
"451 -0.549372 0.415128 0.030292 -1.311380 -0.730766 -0.391255 \n",
"113 0.054593 -1.424076 -0.404784 -1.311380 -0.730766 0.258017 \n",
"556 -0.851355 -0.758158 0.030292 1.213312 -0.730766 0.779980 \n",
"667 1.866489 -0.314212 0.030292 0.392787 -0.730766 -0.569486 \n",
"107 0.054593 0.732232 -0.622322 0.455904 0.569759 -0.314870 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome_1 Glucose_Insulin \n",
"353 0.310794 -0.792169 0.0 2.958266 \n",
"711 -0.116439 0.561034 0.0 -0.306696 \n",
"373 -0.764862 -0.707594 0.0 -3.541575 \n",
"46 0.262314 -0.369293 0.0 -1.088792 \n",
"682 -0.337630 -0.961320 0.0 -3.358486 \n",
".. ... ... ... ... \n",
"451 0.195653 -0.876744 1.0 -0.568071 \n",
"113 -0.261879 -0.707594 0.0 1.948744 \n",
"556 -0.786072 -0.284718 0.0 1.037483 \n",
"667 -1.019383 0.561034 1.0 0.429976 \n",
"107 -0.577001 0.307308 0.0 1.285160 \n",
"\n",
"[614 rows x 10 columns]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Формирование набора моделей для классификации\n",
"logistic -- логистическая регрессия\n",
"\n",
"ridge -- гребневая регрессия\n",
"\n",
"decision_tree -- дерево решений\n",
"\n",
"knn -- k-ближайших соседей\n",
"\n",
"naive_bayes -- наивный Байесовский классификатор\n",
"\n",
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"\n",
"random_forest -- метод случайного леса (набор деревьев решений)\n",
"\n",
"mlp -- многослойный персептрон (нейронная сеть)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"random_state = 42\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression(random_state=random_state)},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\", random_state=random_state)},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210, random_state=random_state)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сводная таблица оценок качества для использованных моделей классификации"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA3EAAAQ9CAYAAAD3ScTVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVwU9f8H8NcAwiI3iCwoIt544pWRZ4ZhpWna1yz6hebR5Z1npRzelkl4lqVoaWalppWWt6lkeaaJNyomoImAoFy7n98f5NYKi7s6wM7wen4f8/i6n/nM7Gc24eV79jMzkhBCgIiIiIiIiBTBpqIHQEREREREROZjEUdERERERKQgLOKIiIiIiIgUhEUcERERERGRgrCIIyIiIiIiUhAWcURERERERArCIo6IiIiIiEhBWMQREREREREpCIs4IiIiIiIiBWERR5VefHw8JEnCxYsXy2T/Fy9ehCRJiI+Pl2V/u3btgiRJ2LVrlyz7IyIiUouoqChIkmRWX0mSEBUVVbYDIiojLOKIrNSiRYtkK/yIiIiISD3sKnoARGoXEBCAO3fuoEqVKhZtt2jRIlSrVg0DBgwwau/UqRPu3LkDe3t7GUdJRESkfO+99x4mTpxY0cMgKnMs4ojKmCRJ0Gg0su3PxsZG1v0RERGpQU5ODpycnGBnx3/ekvpxOiVRCRYtWoQmTZrAwcEBfn5+eOutt5CRkVGs38KFC1GnTh04OjrikUcewS+//IIuXbqgS5cuhj4lXROXmpqKgQMHombNmnBwcICvry969epluC6vdu3a+PPPP7F7925IkgRJkgz7NHVN3IEDB/D000/Dw8MDTk5OaN68OT766CN5PxgiIiIrcPfat5MnT+Kll16Ch4cHOnToUOI1cXl5eRg9ejS8vb3h4uKCZ599FleuXClxv7t27UKbNm2g0WhQt25dfPzxxyavs/viiy/QunVrODo6wtPTE/3790dycnKZHC/RvXiqgugeUVFRiI6ORmhoKN544w2cPn0aixcvxu+//459+/YZpkUuXrwYw4YNQ8eOHTF69GhcvHgRvXv3hoeHB2rWrFnqe/Tt2xd//vknhg8fjtq1a+PatWvYunUrLl++jNq1ayM2NhbDhw+Hs7Mz3n33XQCAj4+Pyf1t3boVPXr0gK+vL0aOHAmtVovExER8//33GDlypHwfDhERkRX53//+h/r162PGjBkQQuDatWvF+gwePBhffPEFXnrpJTz22GPYsWMHnnnmmWL9jhw5gu7du8PX1xfR0dHQ6XSIiYmBt7d3sb7Tp0/H5MmT0a9fPwwePBjXr1/H/Pnz0alTJxw5cgTu7u5lcbhE/xJEldzy5csFAJGUlCSuXbsm7O3txZNPPil0Op2hz4IFCwQAsWzZMiGEEHl5ecLLy0u0bdtWFBQUGPrFx8cLAKJz586GtqSkJAFALF++XAghxM2bNwUA8f7775c6riZNmhjt566dO3cKAGLnzp1CCCEKCwtFYGCgCAgIEDdv3jTqq9frzf8giIiIFCIyMlIAEC+++GKJ7XcdPXpUABBvvvmmUb+XXnpJABCRkZGGtp49e4qqVauKv/76y9B29uxZYWdnZ7TPixcvCltbWzF9+nSjfR4/flzY2dkVaycqC5xOSfQf27ZtQ35+PkaNGgUbm39/PIYMGQJXV1f88MMPAICDBw/ixo0bGDJkiNHc+/DwcHh4eJT6Ho6OjrC3t8euXbtw8+bNhx7zkSNHkJSUhFGjRhU782fubZaJiIiU6PXXXy91/Y8//ggAGDFihFH7qFGjjF7rdDps27YNvXv3hp+fn6G9Xr16eOqpp4z6rlu3Dnq9Hv369cPff/9tWLRaLerXr4+dO3c+xBERmYfTKYn+49KlSwCAhg0bGrXb29ujTp06hvV3/79evXpG/ezs7FC7du1S38PBwQGzZ8/G22+/DR8fHzz66KPo0aMHXnnlFWi1WovHfP78eQBA06ZNLd6WiIhIyQIDA0tdf+nSJdjY2KBu3bpG7ffm/LVr13Dnzp1iuQ4Uz/qzZ89CCIH69euX+J6W3o2a6EGwiCOqAKNGjULPnj2xYcMG/PTTT5g8eTJmzpyJHTt2oGXLlhU9PCIiIkVwdHQs9/fU6/WQJAmbN2+Gra1tsfXOzs7lPiaqfDidkug/AgICAACnT582as/Pz0dSUpJh/d3/P3funFG/wsJCwx0m76du3bp4++238fPPP+PEiRPIz8/H3LlzDevNnQp59+ziiRMnzOpPRERUWQQEBECv1xtmrdx1b85Xr14dGo2mWK4DxbO+bt26EEIgMDAQoaGhxZZHH31U/gMhugeLOKL/CA0Nhb29PeLi4iCEMLR/9tlnyMzMNNzNqk2bNvDy8sLSpUtRWFho6Ldq1ar7Xud2+/Zt5ObmGrXVrVsXLi4uyMvLM7Q5OTmV+FiDe7Vq1QqBgYGIjY0t1v+/x0BERFTZ3L2eLS4uzqg9NjbW6LWtrS1CQ0OxYcMGXL161dB+7tw5bN682ahvnz59YGtri+jo6GI5K4TAjRs3ZDwCopJxOiXRf3h7e2PSpEmIjo5G9+7d8eyzz+L06dNYtGgR2rZti5dffhlA0TVyUVFRGD58OLp27Yp+/frh4sWLiI+PR926dUv9Fu3MmTN44okn0K9fPzRu3Bh2dnZYv3490tLS0L9/f0O/1q1bY/HixZg2bRrq1auH6tWro2vXrsX2Z2Njg8WLF6Nnz54IDg7GwIED4evri1OnTuHPP//ETz/9JP8HRUREpADBwcF48cUXsWjRImRmZuKxxx7D9u3bS/zGLSoqCj///DPat2+PN954AzqdDgsWLEDTpk1x9OhRQ7+6deti2rRpmDRpkuHxQi4uLkhKSsL69esxdOhQjB07thyPkiojFnFE94iKioK3tzcWLFiA0aNHw9PTE0OHDsWMGTOMLlYeNmwYhBCYO3cuxo4dixYtWmDjxo0YMWIENBqNyf37+/vjxRdfxPbt2/H555/Dzs4OjRo1wtq1a9G3b19DvylTpuDSpUuYM2cObt26hc6dO5dYxAFAWFgYdu7ciejoaMydOxd6vR5169bFkCFD5PtgiIiIFGjZsmXw9vbGqlWrsGHDBnTt2hU//PAD/P39jfq1bt0amzdvxtixYzF58mT4+/sjJiYGiYmJOHXqlFHfiRMnokGDBpg3bx6io6MBFOX7k08+iWeffbbcjo0qL0lwvhWRbPR6Pby9vdGnTx8sXbq0oodDRERED6l37974888/cfbs2YoeCpEBr4kjekC5ubnF5sKvXLkS6enp6NKlS8UMioiIiB7YnTt3jF6fPXsWP/74I3OdrA6/iSN6QLt27cLo0aPxv//9D15eXjh8+DA+++wzBAUF4dChQ7C3t6/oIRIREZEFfH19MWDAAMOzYRcvXoy8vDwcOXLE5HPhiCoCr4kjekC1a9eGv78/4uLikJ6eDk9PT7zyyiuYNWsWCzgiIiIF6t69O7788kukpqbCwcEBISEhmDFjBgs4sjqcTkn0gGrXro2NGzciNTUV+fn5SE1NxbJly1C9evWKHhqpxJ49e9CzZ0/4+flBkiRs2LDBaL0QAlOmTIGvry8cHR0RGhpa7JqN9PR0hIeHw9XVFe7u7hg0aBCys7PL8SiIiJRj+fLluHjxInJzc5GZmYktW7agVatWFT0ssiLWks0s4oiIrFROTg5atGiBhQsXlrh+zpw5iIuLw5IlS3DgwAE4OTkhLCzM6DmE4eHh+PPPP7F161Z8//332LNnD4YOHVpeh0BERKQq1pLNvCaOiEgBJEnC+vXr0bt3bwBFZ/r8/Pzw9ttvG55HlJmZCR8fH8THx6N///5ITExE48aN8fvvv6NNmzYAgC1btuDpp5/GlStX4OfnV1GHQ0REpHgVmc28Jq4S0ev1uHr1KlxcXEp9GDWRWgkhcOvWLfj5+cHGRt6JCLm5ucjPz7/v+9/7s+fg4AAHBweL3y8pKQmpqakIDQ01tLm5uaFdu3ZISEhA//79kZCQAHd3d0NIAEBoaChsbGxw4MABPPfccxa/LxHJi9lMlR2z+cGymUVcJXL16tV
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"n_rows = int(len(class_models) / 2)\n",
"n_cols = 2\n",
"\n",
"fig, ax = plt.subplots(n_rows, n_cols, figsize=(12, 10), sharex=False, sharey=False)\n",
"\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"No Diabetes\", \"Diabetes\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Значение 100 - это количество вернных диагнозов (True Positives), там модель верно определила людей у которых нет диабета \"No Diabetes\".\n",
"\n",
"Значение 54 у некоторых моделей - это количество неверных диагнозов (False Negatives), там модель неверно определила людей с диабетом, те, у кого нет диабета \"No Diabetes\" были отнесены к классу \"Diabetes\".\n",
"\n",
"Исходя из истинных и ложных значений (True Positives и False Negatives), можно сделать вывод, что модель имеет высокую точность при предсказании людей без диабета \"No Diabetes\". Уровень ложных результатов в некотоорых моделях со значением 54 говорит о том, что есть такие данные, которые модель пропускает.\n",
"\n",
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_fed29_row0_col0, #T_fed29_row0_col1, #T_fed29_row0_col2, #T_fed29_row0_col3, #T_fed29_row1_col0, #T_fed29_row1_col1, #T_fed29_row1_col2, #T_fed29_row1_col3, #T_fed29_row2_col0, #T_fed29_row2_col1, #T_fed29_row2_col2, #T_fed29_row2_col3, #T_fed29_row3_col0, #T_fed29_row3_col1, #T_fed29_row3_col2, #T_fed29_row3_col3, #T_fed29_row4_col0, #T_fed29_row4_col1, #T_fed29_row4_col2, #T_fed29_row4_col3, #T_fed29_row5_col0, #T_fed29_row5_col1, #T_fed29_row5_col2, #T_fed29_row5_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_fed29_row0_col4, #T_fed29_row0_col5, #T_fed29_row0_col6, #T_fed29_row0_col7, #T_fed29_row1_col4, #T_fed29_row1_col5, #T_fed29_row1_col6, #T_fed29_row1_col7, #T_fed29_row2_col4, #T_fed29_row2_col5, #T_fed29_row2_col6, #T_fed29_row2_col7, #T_fed29_row3_col4, #T_fed29_row3_col5, #T_fed29_row3_col6, #T_fed29_row3_col7, #T_fed29_row4_col4, #T_fed29_row4_col5, #T_fed29_row4_col6, #T_fed29_row4_col7, #T_fed29_row5_col4, #T_fed29_row5_col5, #T_fed29_row5_col6, #T_fed29_row5_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fed29_row6_col0 {\n",
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
"#T_fed29_row6_col1 {\n",
" background-color: #4ac16d;\n",
" color: #000000;\n",
"}\n",
"#T_fed29_row6_col2 {\n",
" background-color: #75d054;\n",
" color: #000000;\n",
"}\n",
"#T_fed29_row6_col3 {\n",
" background-color: #69cd5b;\n",
" color: #000000;\n",
"}\n",
"#T_fed29_row6_col4 {\n",
" background-color: #c43e7f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fed29_row6_col5 {\n",
" background-color: #ae2892;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fed29_row6_col6 {\n",
" background-color: #cc4977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fed29_row6_col7 {\n",
" background-color: #c13b82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fed29_row7_col0, #T_fed29_row7_col1, #T_fed29_row7_col2, #T_fed29_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fed29_row7_col4, #T_fed29_row7_col5, #T_fed29_row7_col6, #T_fed29_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_fed29\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_fed29_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_fed29_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_fed29_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_fed29_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_fed29_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_fed29_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_fed29_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_fed29_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_fed29_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_fed29_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_fed29_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_fed29_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_fed29_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_fed29_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_fed29_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_fed29_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_fed29_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fed29_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_fed29_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_fed29_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_fed29_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_fed29_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_fed29_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_fed29_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_fed29_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_fed29_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fed29_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_fed29_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_fed29_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_fed29_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_fed29_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_fed29_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_fed29_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_fed29_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_fed29_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fed29_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_fed29_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_fed29_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_fed29_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_fed29_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_fed29_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" <td id=\"T_fed29_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
" <td id=\"T_fed29_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
" <td id=\"T_fed29_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fed29_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_fed29_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_fed29_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_fed29_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_fed29_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_fed29_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" <td id=\"T_fed29_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
" <td id=\"T_fed29_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
" <td id=\"T_fed29_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fed29_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_fed29_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_fed29_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_fed29_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_fed29_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_fed29_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" <td id=\"T_fed29_row5_col5\" class=\"data row5 col5\" >1.000000</td>\n",
" <td id=\"T_fed29_row5_col6\" class=\"data row5 col6\" >1.000000</td>\n",
" <td id=\"T_fed29_row5_col7\" class=\"data row5 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fed29_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_fed29_row6_col0\" class=\"data row6 col0\" >0.947090</td>\n",
" <td id=\"T_fed29_row6_col1\" class=\"data row6 col1\" >0.796296</td>\n",
" <td id=\"T_fed29_row6_col2\" class=\"data row6 col2\" >0.836449</td>\n",
" <td id=\"T_fed29_row6_col3\" class=\"data row6 col3\" >0.796296</td>\n",
" <td id=\"T_fed29_row6_col4\" class=\"data row6 col4\" >0.926710</td>\n",
" <td id=\"T_fed29_row6_col5\" class=\"data row6 col5\" >0.857143</td>\n",
" <td id=\"T_fed29_row6_col6\" class=\"data row6 col6\" >0.888337</td>\n",
" <td id=\"T_fed29_row6_col7\" class=\"data row6 col7\" >0.796296</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fed29_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_fed29_row7_col0\" class=\"data row7 col0\" >0.552632</td>\n",
" <td id=\"T_fed29_row7_col1\" class=\"data row7 col1\" >0.428571</td>\n",
" <td id=\"T_fed29_row7_col2\" class=\"data row7 col2\" >0.098131</td>\n",
" <td id=\"T_fed29_row7_col3\" class=\"data row7 col3\" >0.111111</td>\n",
" <td id=\"T_fed29_row7_col4\" class=\"data row7 col4\" >0.657980</td>\n",
" <td id=\"T_fed29_row7_col5\" class=\"data row7 col5\" >0.636364</td>\n",
" <td id=\"T_fed29_row7_col6\" class=\"data row7 col6\" >0.166667</td>\n",
" <td id=\"T_fed29_row7_col7\" class=\"data row7 col7\" >0.176471</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x135c39b70b0>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Почти все модели в данной выборке, а именно логистическая регрессия, ридж-регрессия, дерево решений, наивный байесовский классификатор, случайный лес, градиентный бустинг, KNN демонстрируют неплохие значения по всем метрикам на обучающих и тестовых наборах данных. Модель MLP не так эффективна по сравнению с другими, но в некоторых метриках показывает неплохие результаты."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_2bffc_row0_col0, #T_2bffc_row0_col1, #T_2bffc_row1_col0, #T_2bffc_row1_col1, #T_2bffc_row2_col0, #T_2bffc_row2_col1, #T_2bffc_row3_col0, #T_2bffc_row3_col1, #T_2bffc_row4_col0, #T_2bffc_row4_col1, #T_2bffc_row5_col0, #T_2bffc_row5_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_2bffc_row0_col2, #T_2bffc_row0_col3, #T_2bffc_row0_col4, #T_2bffc_row1_col2, #T_2bffc_row1_col3, #T_2bffc_row1_col4, #T_2bffc_row2_col2, #T_2bffc_row2_col3, #T_2bffc_row2_col4, #T_2bffc_row3_col2, #T_2bffc_row3_col3, #T_2bffc_row3_col4, #T_2bffc_row4_col2, #T_2bffc_row4_col3, #T_2bffc_row4_col4, #T_2bffc_row5_col2, #T_2bffc_row5_col3, #T_2bffc_row5_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2bffc_row6_col0 {\n",
" background-color: #42be71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2bffc_row6_col1 {\n",
" background-color: #65cb5e;\n",
" color: #000000;\n",
"}\n",
"#T_2bffc_row6_col2 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2bffc_row6_col3 {\n",
" background-color: #b7318a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2bffc_row6_col4 {\n",
" background-color: #b6308b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2bffc_row7_col0, #T_2bffc_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_2bffc_row7_col2, #T_2bffc_row7_col3, #T_2bffc_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_2bffc\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_2bffc_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_2bffc_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_2bffc_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_2bffc_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_2bffc_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_2bffc_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_2bffc_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_2bffc_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_2bffc_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_2bffc_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_2bffc_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2bffc_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_2bffc_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_2bffc_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_2bffc_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_2bffc_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_2bffc_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2bffc_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_2bffc_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_2bffc_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_2bffc_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_2bffc_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_2bffc_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2bffc_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_2bffc_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_2bffc_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_2bffc_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_2bffc_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_2bffc_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2bffc_level0_row4\" class=\"row_heading level0 row4\" >gradient_boosting</th>\n",
" <td id=\"T_2bffc_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_2bffc_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_2bffc_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_2bffc_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_2bffc_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2bffc_level0_row5\" class=\"row_heading level0 row5\" >random_forest</th>\n",
" <td id=\"T_2bffc_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_2bffc_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_2bffc_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_2bffc_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_2bffc_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2bffc_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_2bffc_row6_col0\" class=\"data row6 col0\" >0.857143</td>\n",
" <td id=\"T_2bffc_row6_col1\" class=\"data row6 col1\" >0.796296</td>\n",
" <td id=\"T_2bffc_row6_col2\" class=\"data row6 col2\" >0.923333</td>\n",
" <td id=\"T_2bffc_row6_col3\" class=\"data row6 col3\" >0.686296</td>\n",
" <td id=\"T_2bffc_row6_col4\" class=\"data row6 col4\" >0.686296</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_2bffc_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_2bffc_row7_col0\" class=\"data row7 col0\" >0.636364</td>\n",
" <td id=\"T_2bffc_row7_col1\" class=\"data row7 col1\" >0.176471</td>\n",
" <td id=\"T_2bffc_row7_col2\" class=\"data row7 col2\" >0.630556</td>\n",
" <td id=\"T_2bffc_row7_col3\" class=\"data row7 col3\" >0.037500</td>\n",
" <td id=\"T_2bffc_row7_col4\" class=\"data row7 col4\" >0.051640</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x135c3bf80b0>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Такой же вывод можно сделать и для следующих метрик: Accuracy, F1, ROC AUC, Cohen's Kappa и MCC. Все модели, кроме KNN и MLP, указывают на хорошо-развитую способность к выделению классов"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Predicted</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [Pregnancies, Predicted, Glucose, BloodPressure, SkinThickness, Insulin, BMI, DiabetesPedigreeFunction, Age, Outcome]\n",
"Index: []"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"Outcome\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>190</th>\n",
" <td>3.0</td>\n",
" <td>111.0</td>\n",
" <td>62.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>22.6</td>\n",
" <td>0.142</td>\n",
" <td>21.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"190 3.0 111.0 62.0 0.0 0.0 22.6 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"190 0.142 21.0 0.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome_1</th>\n",
" <th>Glucose_Insulin</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>190</th>\n",
" <td>-0.24739</td>\n",
" <td>-0.314212</td>\n",
" <td>-0.404784</td>\n",
" <td>-1.31138</td>\n",
" <td>-0.730766</td>\n",
" <td>-1.193296</td>\n",
" <td>-1.016353</td>\n",
" <td>-1.045895</td>\n",
" <td>0.0</td>\n",
" <td>0.429976</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"190 -0.24739 -0.314212 -0.404784 -1.31138 -0.730766 -1.193296 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome_1 Glucose_Insulin \n",
"190 -1.016353 -1.045895 0.0 0.429976 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 0 (proba: [0.99177252 0.00822748])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 0'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 190\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\tabee\\AIM_PIbd-31_Tabeev_A.P\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 5,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 10}"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"import pandas as pd\n",
"\n",
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
"\n",
"random_state = 42\n",
"# Определение трансформера\n",
"pipeline_end = ColumnTransformer([\n",
" ('numeric', StandardScaler(), numeric_features),\n",
"])\n",
"\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=50,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"# Обучение модели\n",
"result[\"pipeline\"] = Pipeline([\n",
" (\"pipeline\", pipeline_end),\n",
" (\"model\", optimized_model)\n",
"]).fit(X_train, y_train.values.ravel())\n",
"\n",
"# Прогнозирование и расчет метрик\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"# Метрики для оценки модели\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_02d85_row0_col0, #T_02d85_row0_col1, #T_02d85_row0_col2, #T_02d85_row0_col3, #T_02d85_row1_col0, #T_02d85_row1_col1, #T_02d85_row1_col2, #T_02d85_row1_col3 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_02d85_row0_col4, #T_02d85_row0_col5, #T_02d85_row0_col6, #T_02d85_row0_col7, #T_02d85_row1_col4, #T_02d85_row1_col5, #T_02d85_row1_col6, #T_02d85_row1_col7 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_02d85\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_02d85_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_02d85_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_02d85_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_02d85_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_02d85_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_02d85_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_02d85_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_02d85_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_02d85_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_02d85_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_02d85_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_02d85_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_02d85_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_02d85_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_02d85_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_02d85_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_02d85_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_02d85_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_02d85_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_02d85_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_02d85_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_02d85_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_02d85_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_02d85_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_02d85_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_02d85_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x135c73282c0>"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_34fcd_row0_col0, #T_34fcd_row0_col1, #T_34fcd_row1_col0, #T_34fcd_row1_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_34fcd_row0_col2, #T_34fcd_row0_col3, #T_34fcd_row0_col4, #T_34fcd_row1_col3, #T_34fcd_row1_col4 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_34fcd_row1_col2 {\n",
" background-color: #f0f921;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_34fcd\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_34fcd_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_34fcd_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_34fcd_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_34fcd_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_34fcd_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_34fcd_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_34fcd_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_34fcd_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_34fcd_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_34fcd_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_34fcd_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_34fcd_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_34fcd_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_34fcd_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_34fcd_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_34fcd_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_34fcd_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x135c7328590>"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA5IAAAGxCAYAAAAQ1omjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOw0lEQVR4nO3dfVxUZf7/8fcBBLzhRrzhRhE177A0tVxzK+0Gwy0rs77mLruLpba7ZaZmVlugYom2aS5W6mZJ7s/WbFNTK3fLVss0S1OrlbS8CU3RNhWEFgRmfn+wTjsLKsMczgxzXs/H4zyCczfXTMibz7mucx3D6XQ6BQAAAABALQX5ugEAAAAAgIaFQhIAAAAA4BEKSQAAAACARygkAQAAAAAeoZAEAAAAAHiEQhIAAAAA4BEKSQAAAACARygkAQAAAAAeoZAEAAAAAHiEQhIAAAAA4BEKSQBAQHv//fd18803KyEhQYZhaNWqVW7bnU6nMjMzFR8fr8aNGyslJUVfffWV2z4nTpxQWlqaIiMjFR0drVGjRqm4uNjCdwEAQBV/yTUKSQBAQCspKdGll16q5557rsbtTz31lHJycrRgwQJt3bpVTZs2VWpqqkpLS137pKWl6Z///KfeeecdrV27Vu+//77uueceq94CAAAu/pJrhtPpdHr1TgAAaCAMw9DKlSs1dOhQSVVXbRMSEvTggw9q0qRJkqTCwkLFxsYqNzdXI0aMUF5enrp3765PPvlEl19+uSRp3bp1uvHGG3X48GElJCT46u0AAGzOl7kWUi/vCACA/1JaWqozZ86Ydj6n0ynDMNzWhYWFKSwszKPzHDhwQAUFBUpJSXGti4qKUr9+/bRlyxaNGDFCW7ZsUXR0tCtsJSklJUVBQUHaunWrbrvtNu/eDACgwSHXKCQBAPWstLRUHZKaqeB4pWnnbNasWbV7OaZMmaKpU6d6dJ6CggJJUmxsrNv62NhY17aCggK1bt3abXtISIhiYmJc+wAA7INc+88xHrUMAAAPnTlzRgXHK3Vge5IiI7y/Nb/otEMdLvtGhw4dUmRkpGu9p1dtAQCoC3KtCoUkAMASkRFBpgSu63yRkW6BWxdxcXGSpGPHjik+Pt61/tixY+rVq5drn+PHj7sdV1FRoRMnTriOBwDYj91zjVlbAQCWqHQ6TFvM0qFDB8XFxWn9+vWudUVFRdq6dav69+8vSerfv79OnTql7du3u/Z577335HA41K9fP9PaAgBoWOyea/RIAgAs4ZBTDnk/Ubin5yguLtbXX3/t+v7AgQPauXOnYmJi1K5dO40fP15PPPGEOnfurA4dOigjI0MJCQmuGfCSk5M1ePBgjRkzRgsWLFB5ebnGjh2rESNGMGMrANiY3XONQhIAENC2bduma6+91vX9xIkTJUnp6enKzc3V5MmTVVJSonvuuUenTp3SVVddpXXr1ik8PNx1zNKlSzV27Fhdf/31CgoK0u23366cnBzL3wsAAP6SazxHEgBQr4qKihQVFaUje9qaNilBQtfDKiws9PpeEgAAPEWuVaFHEgBgiUqnU5UmXLs04xwAAHjL7rnGZDsAAAAAAI/QIwkAsISvJiUAAKA+2D3XKCQBAJZwyKlKGwcuACCw2D3XGNoKAAAAAPAIPZIAAEvYfQgQACCw2D3X6JEEAAAAAHiEHkkAgCXsPk06ACCw2D3XKCQBAJZw/Gcx4zwAAPia3XONoa0AAAAAAI/QIwkAsESlSdOkm3EOAAC8Zfdco5AEAFii0lm1mHEeAAB8ze65xtBWAAAAAIBH6JEEAFjC7pMSAAACi91zjUISAGAJhwxVyjDlPAAA+Jrdc42hrQAAAAAAj9AjCQCwhMNZtZhxHgAAfM3uuUaPJAAAAADAI/RIAgAsUWnSvSRmnAMAAG/ZPdcoJAEAlrB74AIAAovdc42hrQAAAAAAj9AjCQCwhMNpyOE0YZp0E84BAIC37J5rFJIAAEvYfQgQACCw2D3XGNoKAAAAAPAIPZIAAEtUKkiVJly/rDShLQAAeMvuuUYhCQCwhNOke0mcDfReEgBAYLF7rjG0FQAAAADgEXokAQCWsPukBACAwGL3XKOQBABYotIZpEqnCfeSOE1oDAAAXrJ7rjG0FQAAAADgEXokAQCWcMiQw4Trlw410Eu3AICAYvdco0cSAAAAAOAReiQBAJaw+6QEAIDAYvdco5AEAFjCvEkJGuYQIABAYLF7rjG0FQAAAADgEXokAQCWqJqUwPvhO2acAwAAb9k91ygkAQCWcChIlTae3Q4AEFjsnmsMbQUAAAAAeIQeSQCAJew+KQEAILDYPdcoJAEAlnAoyNYPbgYABBa75xpDWwEAAAAAHqFHEgBgiUqnoUqnCQ9uNuEcAAB4y+65Ro8kAAAAAMAj9EgCACxRadI06ZUN9F4SAEBgsXuuUUgCACzhcAbJYcLsdo4GOrsdACCw2D3XGNoKAAAAAPAIPZIAAEvYfQgQACCw2D3XKCQBAJZwyJyZ6RzeNwUAAK/ZPdcY2goAAAAA8Ag9kgAASzgUJIcJ1y/NOAcAAN6ye65RSAIALFHpDFKlCbPbmXEOAAC8Zfdca5itBgAAAAD4DD2SAABLOGTIITMmJfD+HAAAeMvuuUYhCQCwhN2HAAEAAovdc61hthoAAAAA4DP0SAIALGHeg5u5BgoA8D2751rDbDUAAAAAwGfokbQRh8OhI0eOKCIiQobRMG/qBWAtp9Op06dPKyEhQUFB3l17dDgNOZwmTEpgwjkQGMg1AJ4i18xDIWkjR44cUWJioq+bAaABOnTokNq2bevVORwmDQFqqA9uhvnINQB1Ra55j0LSRiIiIiRJ33zaXpHNGuYPLOrPbV16+LoJ8EMVKtcmveX6/QH4E3IN50OuoSbkmnkoJG3k7LCfyGZBiowgcOEuxGjk6ybAHzmr/mPGsEGHM0gOE6Y4N+McCAzkGs6HXEONyDXTUEgCACxRKUOVJjx02YxzAADgLbvnWsMsfwEAAAAAPkOPJADAEnYfAgQACCx2zzUKSQCAJSplzvCdSu+bAgCA1+yeaw2z/AUAAAAA+Aw9kgAAS9h9CBAAILDYPdcaZqsBAAAAAD5DjyQAwBKVziBVmnDV1YxzAADgLbvnWsNsNQCgwXHKkMOExenBxAaVlZXKyMhQhw4d1LhxY1100UWaPn26nE7nj+1yOpWZman4+Hg1btxYKSkp+uqrr+rjIwAABBC75xqFJAAgYM2aNUvz58/Xs88+q7y8PM2aNUtPPfWU5s2b59rnqaeeUk5OjhYsWKCtW7eqadOmSk1NVWlpqQ9bDgBAdf6UawxtBQBYwhdDgDZv3qxbb71VN910kySpffv2+stf/qKPP/5YUtVV27lz5+rxxx/XrbfeKklasmSJYmNjtWrVKo0YMcLr9gIAApPdc40eSQCAJRxOw7RFkoqKityWsrKyaq/505/+VOvXr9fevXslSbt27dKmTZv0s5/9TJJ04MABFRQUKCUlxXVMVFSU+vXrpy1btljwqQAAGiq75xo9kgCABikxMdHt+ylTpmjq1Klu6x555BEVFRWpW7duCg4OVmVlpZ588kmlpaVJkgoKCiRJsbGxbsfFxsa6tgEAYIWGlmsUkgAAS1QqSJUmDIQ5e45Dhw4pMjLStT4sLKzavsuXL9fSpUv1yiuv6OKLL9bOnTs1fvx4JSQkKD093eu2AADsy+65RiEJALDEfw/f8fY8khQZGekWuDV56KGH9Mgjj7juCenRo4e++eYbZWdnKz09XXFxcZKkY8eOKT4+3nXcsWPH1KtXL6/bCgAIXHbPNe6RBAAErB9++EFBQe5RFxwcLIfDIUnq0KGD4uLitH79etf2oqIibd26Vf3797e0rQAAXIg/5Ro9kgAASzgUJIcJ1y89OcfNN9+sJ598Uu3atdPFF1+sHTt2aM6cObr77rslSYZhaPz48XriiSfUuXNndejQQRkZGUpISNDQoUO9bisAIHDZPdcoJAEAlqh0Gqo0YQiQJ+eYN2+eMjIydO+99+r48eNKSEjQb37zG2VmZrr2mTx5skpKSnTPPffo1KlTuuq
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"No Diabetes\", \"Diabetes\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Значение 100 означает количество верно классифицированных объектов, которые относятся к классу \"No Diabetes\". Можно сделать вывод, что модель отлично идентифицирует объекты этого класса.\n",
"\n",
"Значения 0 означают количество верно классифицированных объектов, которые относятся к классу \"Diabetes\". Можно сделать вывод, что это не высокая точность модели в определении объектов данного класса"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(700, 8)\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6</td>\n",
" <td>98</td>\n",
" <td>58</td>\n",
" <td>33</td>\n",
" <td>190</td>\n",
" <td>34.0</td>\n",
" <td>0.430</td>\n",
" <td>43</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>112</td>\n",
" <td>75</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.7</td>\n",
" <td>0.148</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>108</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.8</td>\n",
" <td>0.158</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>8</td>\n",
" <td>107</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.6</td>\n",
" <td>0.856</td>\n",
" <td>34</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>7</td>\n",
" <td>136</td>\n",
" <td>90</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29.9</td>\n",
" <td>0.210</td>\n",
" <td>50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>695</th>\n",
" <td>2</td>\n",
" <td>105</td>\n",
" <td>80</td>\n",
" <td>45</td>\n",
" <td>191</td>\n",
" <td>33.7</td>\n",
" <td>0.711</td>\n",
" <td>29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>696</th>\n",
" <td>1</td>\n",
" <td>126</td>\n",
" <td>56</td>\n",
" <td>29</td>\n",
" <td>152</td>\n",
" <td>28.7</td>\n",
" <td>0.801</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>697</th>\n",
" <td>2</td>\n",
" <td>95</td>\n",
" <td>54</td>\n",
" <td>14</td>\n",
" <td>88</td>\n",
" <td>26.1</td>\n",
" <td>0.748</td>\n",
" <td>22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>698</th>\n",
" <td>3</td>\n",
" <td>100</td>\n",
" <td>68</td>\n",
" <td>23</td>\n",
" <td>81</td>\n",
" <td>31.6</td>\n",
" <td>0.949</td>\n",
" <td>28</td>\n",
" </tr>\n",
" <tr>\n",
" <th>699</th>\n",
" <td>1</td>\n",
" <td>85</td>\n",
" <td>66</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>26.6</td>\n",
" <td>0.351</td>\n",
" <td>31</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>700 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 98 58 33 190 34.0 \n",
"1 2 112 75 32 0 35.7 \n",
"2 2 108 64 0 0 30.8 \n",
"3 8 107 80 0 0 24.6 \n",
"4 7 136 90 0 0 29.9 \n",
".. ... ... ... ... ... ... \n",
"695 2 105 80 45 191 33.7 \n",
"696 1 126 56 29 152 28.7 \n",
"697 2 95 54 14 88 26.1 \n",
"698 3 100 68 23 81 31.6 \n",
"699 1 85 66 29 0 26.6 \n",
"\n",
" DiabetesPedigreeFunction Age \n",
"0 0.430 43 \n",
"1 0.148 21 \n",
"2 0.158 21 \n",
"3 0.856 34 \n",
"4 0.210 50 \n",
".. ... ... \n",
"695 0.711 29 \n",
"696 0.801 21 \n",
"697 0.748 22 \n",
"698 0.949 28 \n",
"699 0.351 31 \n",
"\n",
"[700 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"random_state = 42\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"df = pd.read_csv(\"data/diabetes.csv\")\n",
"\n",
"df = df.drop(columns=[\"Outcome\"])\n",
"\n",
"df = df.sample(n=700, random_state=random_state).reset_index(drop=True)\n",
"\n",
"print(df.shape) \n",
"display(df)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 148 72 35 0 33.6 \n",
"1 1 85 66 29 0 26.6 \n",
"2 8 183 64 0 0 23.3 \n",
"3 1 89 66 23 94 28.1 \n",
"4 0 137 40 35 168 43.1 \n",
"\n",
" DiabetesPedigreeFunction Age diabetes_risk_index \n",
"0 0.627 50 71.68 \n",
"1 0.351 31 46.28 \n",
"2 0.672 32 74.69 \n",
"3 0.167 21 55.33 \n",
"4 2.288 33 81.43 \n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.read_csv(\"data/diabetes.csv\")\n",
"\n",
"required_columns = [\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]\n",
"missing_columns = [col for col in required_columns if col not in df.columns]\n",
"if missing_columns:\n",
" raise ValueError(f\"Отсутствуют столбцы: {missing_columns}\")\n",
"\n",
"df[\"diabetes_risk_index\"] = (\n",
" df[\"Glucose\"] * 0.3 \n",
" + df[\"BMI\"] * 0.3 \n",
" + df[\"Age\"] * 0.2 \n",
" + df[\"BloodPressure\"] * 0.1 \n",
" + df[\"Insulin\"] * 0.1 \n",
")\n",
"\n",
"print(df[[\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \"BMI\", \"DiabetesPedigreeFunction\", \"Age\", \"diabetes_risk_index\"]].head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>60</th>\n",
" <td>2</td>\n",
" <td>84</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.304</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>618</th>\n",
" <td>9</td>\n",
" <td>112</td>\n",
" <td>82</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" <td>28.2</td>\n",
" <td>1.282</td>\n",
" <td>50</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>346</th>\n",
" <td>1</td>\n",
" <td>139</td>\n",
" <td>46</td>\n",
" <td>19</td>\n",
" <td>83</td>\n",
" <td>28.7</td>\n",
" <td>0.654</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>0</td>\n",
" <td>161</td>\n",
" <td>50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.9</td>\n",
" <td>0.254</td>\n",
" <td>65</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231</th>\n",
" <td>6</td>\n",
" <td>134</td>\n",
" <td>80</td>\n",
" <td>37</td>\n",
" <td>370</td>\n",
" <td>46.2</td>\n",
" <td>0.238</td>\n",
" <td>46</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"60 2 84 0 0 0 0.0 \n",
"618 9 112 82 24 0 28.2 \n",
"346 1 139 46 19 83 28.7 \n",
"294 0 161 50 0 0 21.9 \n",
"231 6 134 80 37 370 46.2 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"60 0.304 21 0 \n",
"618 1.282 50 1 \n",
"346 0.654 22 0 \n",
"294 0.254 65 0 \n",
"231 0.238 46 1 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>diabetes_risk_index</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>60</th>\n",
" <td>29.40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>618</th>\n",
" <td>60.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>346</th>\n",
" <td>67.61</td>\n",
" </tr>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>72.87</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231</th>\n",
" <td>108.26</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" diabetes_risk_index\n",
"60 29.40\n",
"618 60.26\n",
"346 67.61\n",
"294 72.87\n",
"231 108.26"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>668</th>\n",
" <td>6</td>\n",
" <td>98</td>\n",
" <td>58</td>\n",
" <td>33</td>\n",
" <td>190</td>\n",
" <td>34.0</td>\n",
" <td>0.430</td>\n",
" <td>43</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>324</th>\n",
" <td>2</td>\n",
" <td>112</td>\n",
" <td>75</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.7</td>\n",
" <td>0.148</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>624</th>\n",
" <td>2</td>\n",
" <td>108</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.8</td>\n",
" <td>0.158</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>690</th>\n",
" <td>8</td>\n",
" <td>107</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.6</td>\n",
" <td>0.856</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>473</th>\n",
" <td>7</td>\n",
" <td>136</td>\n",
" <td>90</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29.9</td>\n",
" <td>0.210</td>\n",
" <td>50</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"668 6 98 58 33 190 34.0 \n",
"324 2 112 75 32 0 35.7 \n",
"624 2 108 64 0 0 30.8 \n",
"690 8 107 80 0 0 24.6 \n",
"473 7 136 90 0 0 29.9 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"668 0.430 43 0 \n",
"324 0.148 21 0 \n",
"624 0.158 21 0 \n",
"690 0.856 34 0 \n",
"473 0.210 50 0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>diabetes_risk_index</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>668</th>\n",
" <td>73.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>324</th>\n",
" <td>56.01</td>\n",
" </tr>\n",
" <tr>\n",
" <th>624</th>\n",
" <td>52.24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>690</th>\n",
" <td>54.28</td>\n",
" </tr>\n",
" <tr>\n",
" <th>473</th>\n",
" <td>68.77</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" diabetes_risk_index\n",
"668 73.00\n",
"324 56.01\n",
"624 52.24\n",
"690 54.28\n",
"473 68.77"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" target_colname: str = \"diabetes_risk_index\",\n",
" frac_train: float = 0.8,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" if target_colname not in df_input.columns:\n",
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
" \n",
" X = df_input.drop(columns=[target_colname])\n",
" y = df_input[[target_colname]]\n",
"\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" \n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
" df, \n",
" target_colname=\"diabetes_risk_index\", \n",
" frac_train=0.8, \n",
" random_state=42\n",
")\n",
"\n",
"display(\"X_train\", X_train.head())\n",
"display(\"y_train\", y_train.head())\n",
"\n",
"display(\"X_test\", X_test.head())\n",
"display(\"y_test\", y_test.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"random_state = 42\n",
"\n",
"models = {\n",
" # Линейная регрессия\n",
" \"linear\": {\n",
" \"model\": linear_model.LinearRegression(n_jobs=-1)\n",
" },\n",
" # Полиномиальная регрессия степени 2\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" StandardScaler(),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" # Полиномиальная регрессия с взаимодействиями\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" StandardScaler(),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" # Ridge-регрессия\n",
" \"ridge\": {\n",
" \"model\": make_pipeline(\n",
" StandardScaler(),\n",
" linear_model.RidgeCV()\n",
" )\n",
" },\n",
" # Регрессия на основе дерева решений\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" # Метод ближайших соседей (kNN)\n",
" \"knn\": {\n",
" \"model\": make_pipeline(\n",
" StandardScaler(),\n",
" neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)\n",
" )\n",
" },\n",
" # Случайный лес (Random Forest)\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" # Нейронная сеть (MLPRegressor)\n",
" \"mlp\": {\n",
" \"model\": make_pipeline(\n",
" StandardScaler(),\n",
" neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение и оценка моделей с помощью различных алгоритмов"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\tabee\\AIM_PIbd-31_Tabeev_A.P\\.venv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
" warnings.warn(\n"
]
}
],
"source": [
"import math\n",
"from pandas import DataFrame\n",
"from sklearn import metrics\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
"\n",
" fitted_model = models[model_name][\"model\"].fit(\n",
" X_train.values, y_train.values.ravel()\n",
" )\n",
" y_train_pred = fitted_model.predict(X_train.values)\n",
" y_test_pred = fitted_model.predict(X_test.values)\n",
" models[model_name][\"fitted\"] = fitted_model\n",
" models[model_name][\"train_preds\"] = y_train_pred\n",
" models[model_name][\"preds\"] = y_test_pred\n",
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_train, y_train_pred)\n",
" )\n",
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод результатов оценки"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_4f71e_row0_col0, #T_4f71e_row0_col1, #T_4f71e_row1_col0, #T_4f71e_row1_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row0_col2, #T_4f71e_row1_col2, #T_4f71e_row7_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row0_col3, #T_4f71e_row1_col3, #T_4f71e_row2_col3, #T_4f71e_row3_col3, #T_4f71e_row7_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row2_col0 {\n",
" background-color: #25838e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row2_col1 {\n",
" background-color: #25858e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row2_col2 {\n",
" background-color: #7100a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row3_col0 {\n",
" background-color: #25848e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row3_col1 {\n",
" background-color: #24878e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row3_col2 {\n",
" background-color: #7501a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row4_col0 {\n",
" background-color: #23888e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row4_col1 {\n",
" background-color: #23898e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row4_col2 {\n",
" background-color: #7a02a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row4_col3, #T_4f71e_row6_col2 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row5_col0, #T_4f71e_row5_col1 {\n",
" background-color: #89d548;\n",
" color: #000000;\n",
"}\n",
"#T_4f71e_row5_col2 {\n",
" background-color: #d24f71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row5_col3 {\n",
" background-color: #7201a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4f71e_row6_col0, #T_4f71e_row7_col0, #T_4f71e_row7_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_4f71e_row6_col1 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_4f71e_row6_col3 {\n",
" background-color: #5302a3;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_4f71e\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_4f71e_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_4f71e_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_4f71e_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_4f71e_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_4f71e_level0_row0\" class=\"row_heading level0 row0\" >linear</th>\n",
" <td id=\"T_4f71e_row0_col0\" class=\"data row0 col0\" >0.000000</td>\n",
" <td id=\"T_4f71e_row0_col1\" class=\"data row0 col1\" >0.000000</td>\n",
" <td id=\"T_4f71e_row0_col2\" class=\"data row0 col2\" >0.000000</td>\n",
" <td id=\"T_4f71e_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4f71e_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_4f71e_row1_col0\" class=\"data row1 col0\" >0.002409</td>\n",
" <td id=\"T_4f71e_row1_col1\" class=\"data row1 col1\" >0.002181</td>\n",
" <td id=\"T_4f71e_row1_col2\" class=\"data row1 col2\" >0.040847</td>\n",
" <td id=\"T_4f71e_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4f71e_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_4f71e_row2_col0\" class=\"data row2 col0\" >1.856691</td>\n",
" <td id=\"T_4f71e_row2_col1\" class=\"data row2 col1\" >3.387390</td>\n",
" <td id=\"T_4f71e_row2_col2\" class=\"data row2 col2\" >1.612410</td>\n",
" <td id=\"T_4f71e_row2_col3\" class=\"data row2 col3\" >0.967338</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4f71e_level0_row3\" class=\"row_heading level0 row3\" >decision_tree</th>\n",
" <td id=\"T_4f71e_row3_col0\" class=\"data row3 col0\" >2.409987</td>\n",
" <td id=\"T_4f71e_row3_col1\" class=\"data row3 col1\" >4.281378</td>\n",
" <td id=\"T_4f71e_row3_col2\" class=\"data row3 col2\" >1.847379</td>\n",
" <td id=\"T_4f71e_row3_col3\" class=\"data row3 col3\" >0.947823</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4f71e_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_4f71e_row4_col0\" class=\"data row4 col0\" >5.170083</td>\n",
" <td id=\"T_4f71e_row4_col1\" class=\"data row4 col1\" >5.824576</td>\n",
" <td id=\"T_4f71e_row4_col2\" class=\"data row4 col2\" >2.073599</td>\n",
" <td id=\"T_4f71e_row4_col3\" class=\"data row4 col3\" >0.903431</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4f71e_level0_row5\" class=\"row_heading level0 row5\" >mlp</th>\n",
" <td id=\"T_4f71e_row5_col0\" class=\"data row5 col0\" >60.594447</td>\n",
" <td id=\"T_4f71e_row5_col1\" class=\"data row5 col1\" >59.994001</td>\n",
" <td id=\"T_4f71e_row5_col2\" class=\"data row5 col2\" >7.553087</td>\n",
" <td id=\"T_4f71e_row5_col3\" class=\"data row5 col3\" >-9.245313</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4f71e_level0_row6\" class=\"row_heading level0 row6\" >linear_poly</th>\n",
" <td id=\"T_4f71e_row6_col0\" class=\"data row6 col0\" >67.720820</td>\n",
" <td id=\"T_4f71e_row6_col1\" class=\"data row6 col1\" >66.518485</td>\n",
" <td id=\"T_4f71e_row6_col2\" class=\"data row6 col2\" >8.144355</td>\n",
" <td id=\"T_4f71e_row6_col3\" class=\"data row6 col3\" >-11.594887</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4f71e_level0_row7\" class=\"row_heading level0 row7\" >linear_interact</th>\n",
" <td id=\"T_4f71e_row7_col0\" class=\"data row7 col0\" >67.518306</td>\n",
" <td id=\"T_4f71e_row7_col1\" class=\"data row7 col1\" >67.518306</td>\n",
" <td id=\"T_4f71e_row7_col2\" class=\"data row7 col2\" >8.216952</td>\n",
" <td id=\"T_4f71e_row7_col3\" class=\"data row7 col3\" >-11.976353</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x135c8320bc0>"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок\n",
"\n",
"Получение лучшей модели"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'linear'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 27 candidates, totalling 81 fits\n",
"Лучшие параметры: {'max_depth': 20, 'min_samples_split': 5, 'n_estimators': 150}\n",
"Лучший результат (MSE): 11.258082426680579\n",
"Ошибка на тестовой выборке (MSE): 9.1015\n",
"Коэффициент детерминации (R²): 0.9741\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Удаление пропущенных значений (если есть)\n",
"df.dropna(inplace=True)\n",
"\n",
"X = df[[\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \n",
" \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]] \n",
"y = df[\"diabetes_risk_index\"] \n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y, test_size=0.2, random_state=42\n",
")\n",
"\n",
"model = RandomForestRegressor(random_state=42)\n",
"\n",
"param_grid = {\n",
" 'n_estimators': [50, 100, 150], \n",
" 'max_depth': [10, 20, 30], \n",
" 'min_samples_split': [5, 10, 15] \n",
"}\n",
"\n",
"grid_search = GridSearchCV(\n",
" estimator=model,\n",
" param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', \n",
" cv=3, \n",
" n_jobs=-1, \n",
" verbose=2 \n",
")\n",
"\n",
"grid_search.fit(X_train, y_train)\n",
"\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)\n",
"\n",
"best_model = grid_search.best_estimator_\n",
"y_pred = best_model.predict(X_test)\n",
"\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"r2 = metrics.r2_score(y_test, y_pred)\n",
"\n",
"print(f\"Ошибка на тестовой выборке (MSE): {mse:.4f}\")\n",
"print(f\"Коэффициент детерминации (R²): {r2:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами и сравнение новых и старых данных"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n",
"Старые параметры: {'max_depth': 20, 'min_samples_split': 5, 'n_estimators': 100}\n",
"Лучший результат (MSE) на старых параметрах: 11.682596479300793\n",
"\n",
"\n",
"Новые параметры: {'max_depth': 20, 'min_samples_split': 10, 'n_estimators': 100}\n",
"Лучший результат (MSE) на новых параметрах: 18.55198575928597\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 10.739599657760765\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 3.27713284103052\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"old_param_grid = {\n",
" 'n_estimators': [50, 100],\n",
" 'max_depth': [ 10, 20],\n",
" 'min_samples_split': [5, 10]\n",
"}\n",
"\n",
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=old_param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"old_grid_search.fit(X_train, y_train)\n",
"\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_\n",
"\n",
"new_param_grid = {\n",
" 'n_estimators': [100],\n",
" 'max_depth': [20],\n",
" 'min_samples_split': [10]\n",
"}\n",
"\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"new_grid_search.fit(X_train, y_train)\n",
"\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_\n",
"\n",
"model_best = RandomForestRegressor(**new_best_params)\n",
"model_best.fit(X_train, y_train)\n",
"\n",
"model_oldbest = RandomForestRegressor(**old_best_params)\n",
"model_oldbest.fit(X_train, y_train)\n",
"\n",
"y_pred = model_best.predict(X_test)\n",
"y_oldpred = model_oldbest.predict(X_test)\n",
"\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\n\")\n",
"print(\"Новые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1wAAAHWCAYAAABjUYhTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzde1xUdf748dcMAzMwzAzCAIIwgTcYtEy0zEtpW5u23aBau7ibdr9T36ys3dKs7bLbZY1q29y1zLbdsi3YLr91t5tdXLNUKpUhrziAKAyXYRiYAWbm98fI5AAaIjhc3s/Hg4fOOWfO+Zwzt/P+XN4fhc/n8yGEEEIIIYQQotcpQ10AIYQQQgghhBisJOASQgghhBBCiD4iAZcQQgghhBBC9BEJuIQQQgghhBCij0jAJYQQQgghhBB9RAIuIYQQQgghhOgjEnAJIYQQQgghRB+RgEsIIYQQQggh+ogEXEIIIYQQQgjRRyTgEkIIIYQQQog+IgGXEEIchZUrV6JQKNi4cWOX62fNmsX48eOPc6mEEEII0V9JwCWEEEIIIYQQfUQCLiGEEEIIIYToIxJwCSFEH2tra+ORRx5h1KhRqNVq0tLS+M1vfoPb7Q7aLi0tDYVCgUKhQKlUMnz4cC677DKsVmtgm9LSUhQKBU899dRhj/fQQw+hUCg6Lf/b3/7GpEmTiIyMJDY2lssvv5yysrKfLH9X+2tsbGT48OEoFArWrl37k/uoqKjg2muvJTk5GbVaTXp6OjfffDMtLS2BbppH+lu5ciUA33//PQsWLGDkyJFoNBqGDx/ONddcQ01NTZdlLikpYe7cuej1euLi4rjjjjtwuVxB2yoUCm677bbDlr29fKWlpUHL//3vf3P66aej1WrR6XScd955bNu27SevxU+d70MPPRTYdu/evdxyyy1kZGQQGRlJXFwcv/zlLzuVpX2fn3/+OTfeeCNxcXHo9Xquuuoq6urqgrb917/+xXnnnRd4LUaNGsUjjzyCx+MJ2m7WrFkoFApycnI6ncONN96IQqHo1H3W6/WybNkyxo0bh0ajITExkRtvvDGoDIe+z7v6S0tLA4Lf63/84x854YQTiIyMZObMmWzdujXouAsWLCA6OvqI173jtRVCiONFFeoCCCHEQGS327HZbJ2Wt7a2dlp23XXX8eqrr3LppZeycOFCNmzYwOOPP47FYqGgoCBo29NPP50bbrgBr9fL1q1bWbZsGfv27eOLL744pvI++uijPPjgg8ydO5frrruO6upqnnvuOc444wyKioqIiYk5qv09/fTTHDhwoFvb7tu3j1NPPZX6+npuuOEGMjMzqaio4J///CdNTU2cccYZvPbaa0FlBfjtb38bWDZt2jQAPvzwQ3bv3s3VV1/N8OHD2bZtG8uXL2fbtm189dVXnQLDuXPnkpaWxuOPP85XX31Ffn4+dXV1rFq16qjOt6PXXnuN+fPnM3v2bH7/+9/T1NTEiy++yIwZMygqKgoEDUfy8MMPk56eHnjc2NjIzTffHLTNN998w//+9z8uv/xyUlJSKC0t5cUXX2TWrFkUFxcTFRUVtP1tt91GTEwMDz30ED/88AMvvvgie/fuZe3atYFrs3LlSqKjo7nrrruIjo7mk08+YfHixTQ0NPDkk08G7U+j0fDBBx9QVVVFQkICAM3Nzbz55ptoNJpO53TjjTeycuVKrr76avLy8tizZw/PP/88RUVFrFu3jvDwcJYtW0ZjYyMAFouFxx57jN/85jeYzWaAToHTqlWrcDgc3HrrrbhcLp599ll+9rOfsWXLFhITE3/yOgshRMj5hBBCdNsrr7ziA474N27cuMD23377rQ/wXXfddUH7ufvuu32A75NPPgksO+GEE3zz588P2u7KK6/0RUVFBR7v2bPHB/iefPLJw5ZxyZIlvkO/3ktLS31hYWG+Rx99NGi7LVu2+FQqVaflP7W/qqoqn06n85177rk+wPfpp58e8flXXXWVT6lU+r755ptO67xeb6dlM2fO9M2cObPLfTU1NXVa9o9//MMH+D7//PNOZb7wwguDtr3lllt8gO+7774LLAN8t95662HL3/6a79mzx+fz+XwOh8MXExPju/7664O2279/v89gMHRafrj9dbwe1dXVPsC3ZMmSI57v+vXrfYBv1apVnfY5adIkX0tLS2D5H/7wBx/g+9e//nXEfd54442+qKgon8vlCiybOXOmb9y4cb6TTjrJ99RTTwWWv/baa76UlBTf6aefHvRe/+KLL3yA7/XXXw/a95o1a7pc7vP5fJ9++ulh30Pt7/XIyEhfeXl5YPmGDRt8gO///u//Asvmz5/v02q1nfZxqI7XVgghjhfpUiiEED3wwgsv8OGHH3b6O+mkk4K2+3//7/8BcNdddwUtX7hwIQAffPBB0HK3243NZqOqqooPP/yQTz75hLPOOqvT8ZuamrDZbNTV1eHz+Y5Y1nfeeQev18vcuXOx2WyBv+HDhzNmzBg+/fTTozr3Rx55BIPBQF5e3k9u6/V6KSws5IILLmDy5Mmd1nfV9fFIIiMjA/93uVzYbDZOO+00ADZv3txp+1tvvTXo8e233w78+Lp03FdNTQ1er/eIZfjwww+pr6/niiuuCLqeYWFhTJky5aiv55Ecer6tra3U1NQwevRoYmJiujzfG264gfDw8MDjm2++GZVKFXS+h+7T4XBgs9k4/fTTaWpqoqSkpNM+r776al555ZXA41deeYX58+ejVAbfQrz11lsYDAZ+/vOfB12XSZMmER0d3ePrkpOTw4gRIwKPTz31VKZMmdLpNQQCx+zYbVQIIUJJuhQKIUQPnHrqqV0GEMOGDQvqarh3716USiWjR48O2m748OHExMSwd+/eoOVvvPEGb7zxRuDxKaecwl//+tdOx1myZAlLliwB/N2+fvazn7Fs2TLGjBnTadsdO3bg8/m6XAcE3aD/lD179vDSSy/x4osvdtmlrKPq6moaGhp6LVV+bW0tS5cu5Y033qCqqipond1u77R9x3MeNWoUSqWy0xioFStWsGLFCgAiIiKYMmUKzzzzTJev8Y4dOwD42c9+1mUZ9Xp9t8/npzQ3N/P444/zyiuvUFFRERRcd+d8o6OjSUpKCjrfbdu28cADD/DJJ5/Q0NAQtH1X+5w3bx733nsvX3/9NQkJCaxdu5aXXnqJL7/8Mmi7HTt2YLfbA10PO+r4enVXV+/bsWPHsnr16qBlTqeT+Pj4wOPU1FQWLlzIHXfc0aPjCiFEb5GASwghjoPutuScc8453HPPPQCUl5fz+9//njPPPJONGzcGtUzccMMN/PKXv8Tj8WCxWHjooYfIycnpMmmD1+tFoVDw73//m7CwsE7rfyrZwKF++9vfMmbMGObPn3/M48p6Yu7cufzvf//jnnvu4eSTTyY6Ohqv18ucOXN+smUKDv86XHTRRdx22234fD727NnDww8/zPnnnx8Irg7VfpzXXnuN4cOHd1qvUvXeT+vtt9/OK6+8wp133snUqVMxGAwoFAouv/zybp1vR/X19cycORO9Xs/DDz/MqFGj0Gg0bN68mUWLFnW5z/j4eC644AJeeeUVEhMTmT59eqcKBPBfl4SEBF5//fUuj31oMNQXNBoN7733HuBvuXv55Ze58847SUpKYu7cuX16bCGEOBIJuIQQog+dcMIJeL1eduzYEUgKAHDgwAHq6+s54YQTgrZPSkri7LPPDjzOyMhg2rRpFBYWcsUVVwSWjxkzJrDd7NmzaWpq4re//W1QRsN2o0aNwufzkZ6eztixY3t8LkVFRbzxxhsUFhZ2Gbh1JT4+Hr1e3ymrXE/U1dXx8ccfs3TpUhYvXhxY3lVQdOi6QxNT7Ny5E6/X2ympRUpKStB1j46OZt68eRQVFXXa56hRowBISEgIek5f+Oc//8n8+fN5+umnA8tcLhf19fVdbr9jxw7OPPPMwOPGxkYqKyv5xS9+AcDatWupqanhnXfe4Ywzzghst2fPniOW45prrmHevHkYDIbDZvobNWoUH330EdOnTw+qHDhWXb2+27dv7/QahoWFBb0e5513HrGxsaxZs0Y
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 5))\n",
"plt.scatter(range(len(y_test)), y_test, label=\"Актуальные значения\", color=\"green\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_pred, label=\"Новые параметры\", color=\"blue\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_oldpred, label=\"Старые параметры\", color=\"red\", alpha=0.5)\n",
"plt.xlabel(\"Выборка\")\n",
"plt.ylabel(\"Значения\")\n",
"plt.legend()\n",
"plt.title(\"Новые и старые параметры\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}