3832 lines
377 KiB
Plaintext
3832 lines
377 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Подключили датафрейм и выгрузили данные"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Index(['Pregnancies', 'Glucose', 'BloodPressure', 'SkinThickness', 'Insulin',\n",
|
||
" 'BMI', 'DiabetesPedigreeFunction', 'Age', 'Outcome'],\n",
|
||
" dtype='object')\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Pregnancies</th>\n",
|
||
" <th>Glucose</th>\n",
|
||
" <th>BloodPressure</th>\n",
|
||
" <th>SkinThickness</th>\n",
|
||
" <th>Insulin</th>\n",
|
||
" <th>BMI</th>\n",
|
||
" <th>DiabetesPedigreeFunction</th>\n",
|
||
" <th>Age</th>\n",
|
||
" <th>Outcome</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>6</td>\n",
|
||
" <td>148</td>\n",
|
||
" <td>72</td>\n",
|
||
" <td>35</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>33.6</td>\n",
|
||
" <td>0.627</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>85</td>\n",
|
||
" <td>66</td>\n",
|
||
" <td>29</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>26.6</td>\n",
|
||
" <td>0.351</td>\n",
|
||
" <td>31</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>8</td>\n",
|
||
" <td>183</td>\n",
|
||
" <td>64</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>23.3</td>\n",
|
||
" <td>0.672</td>\n",
|
||
" <td>32</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>89</td>\n",
|
||
" <td>66</td>\n",
|
||
" <td>23</td>\n",
|
||
" <td>94</td>\n",
|
||
" <td>28.1</td>\n",
|
||
" <td>0.167</td>\n",
|
||
" <td>21</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>0</td>\n",
|
||
" <td>137</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>35</td>\n",
|
||
" <td>168</td>\n",
|
||
" <td>43.1</td>\n",
|
||
" <td>2.288</td>\n",
|
||
" <td>33</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>763</th>\n",
|
||
" <td>10</td>\n",
|
||
" <td>101</td>\n",
|
||
" <td>76</td>\n",
|
||
" <td>48</td>\n",
|
||
" <td>180</td>\n",
|
||
" <td>32.9</td>\n",
|
||
" <td>0.171</td>\n",
|
||
" <td>63</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>764</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>122</td>\n",
|
||
" <td>70</td>\n",
|
||
" <td>27</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>36.8</td>\n",
|
||
" <td>0.340</td>\n",
|
||
" <td>27</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>765</th>\n",
|
||
" <td>5</td>\n",
|
||
" <td>121</td>\n",
|
||
" <td>72</td>\n",
|
||
" <td>23</td>\n",
|
||
" <td>112</td>\n",
|
||
" <td>26.2</td>\n",
|
||
" <td>0.245</td>\n",
|
||
" <td>30</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>766</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>126</td>\n",
|
||
" <td>60</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>30.1</td>\n",
|
||
" <td>0.349</td>\n",
|
||
" <td>47</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>767</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>93</td>\n",
|
||
" <td>70</td>\n",
|
||
" <td>31</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>30.4</td>\n",
|
||
" <td>0.315</td>\n",
|
||
" <td>23</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>768 rows × 9 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
||
"0 6 148 72 35 0 33.6 \n",
|
||
"1 1 85 66 29 0 26.6 \n",
|
||
"2 8 183 64 0 0 23.3 \n",
|
||
"3 1 89 66 23 94 28.1 \n",
|
||
"4 0 137 40 35 168 43.1 \n",
|
||
".. ... ... ... ... ... ... \n",
|
||
"763 10 101 76 48 180 32.9 \n",
|
||
"764 2 122 70 27 0 36.8 \n",
|
||
"765 5 121 72 23 112 26.2 \n",
|
||
"766 1 126 60 0 0 30.1 \n",
|
||
"767 1 93 70 31 0 30.4 \n",
|
||
"\n",
|
||
" DiabetesPedigreeFunction Age Outcome \n",
|
||
"0 0.627 50 1 \n",
|
||
"1 0.351 31 0 \n",
|
||
"2 0.672 32 1 \n",
|
||
"3 0.167 21 0 \n",
|
||
"4 2.288 33 1 \n",
|
||
".. ... ... ... \n",
|
||
"763 0.171 63 0 \n",
|
||
"764 0.340 27 0 \n",
|
||
"765 0.245 30 0 \n",
|
||
"766 0.349 47 1 \n",
|
||
"767 0.315 23 0 \n",
|
||
"\n",
|
||
"[768 rows x 9 columns]"
|
||
]
|
||
},
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"from sklearn import set_config\n",
|
||
"set_config(transform_output=\"pandas\")\n",
|
||
"random_state = 42\n",
|
||
"df = pd.read_csv(\"data/diabetes.csv\")\n",
|
||
"print(df.columns)\n",
|
||
"df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Бизнес-цели\n",
|
||
"1. Предсказание риска развития диабета. Будем классифицировать пациентов на основе медданных для того, чтобы определить у кого есть риск развития диабета(будем использовать целевой признак \"Outcome\"). Актуальность для раннего выявления диабета.\n",
|
||
"2. Анализ ключевых факторов, влияющих на диабет. Предсказание вероятности развития диабета на основе медданных. Актуальность для планирвоания лечения.\n",
|
||
"## Определение достижимого уровня качества модели для первой задачи\n",
|
||
"Разделение данных на обучающую и тестовые выборки 80/20 для задачи классификации"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Pregnancies</th>\n",
|
||
" <th>Glucose</th>\n",
|
||
" <th>BloodPressure</th>\n",
|
||
" <th>SkinThickness</th>\n",
|
||
" <th>Insulin</th>\n",
|
||
" <th>BMI</th>\n",
|
||
" <th>DiabetesPedigreeFunction</th>\n",
|
||
" <th>Age</th>\n",
|
||
" <th>Outcome</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>353</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>90</td>\n",
|
||
" <td>62</td>\n",
|
||
" <td>12</td>\n",
|
||
" <td>43</td>\n",
|
||
" <td>27.2</td>\n",
|
||
" <td>0.580</td>\n",
|
||
" <td>24</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>711</th>\n",
|
||
" <td>5</td>\n",
|
||
" <td>126</td>\n",
|
||
" <td>78</td>\n",
|
||
" <td>27</td>\n",
|
||
" <td>22</td>\n",
|
||
" <td>29.6</td>\n",
|
||
" <td>0.439</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>373</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>105</td>\n",
|
||
" <td>58</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>94</td>\n",
|
||
" <td>34.9</td>\n",
|
||
" <td>0.225</td>\n",
|
||
" <td>25</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>46</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>146</td>\n",
|
||
" <td>56</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>29.7</td>\n",
|
||
" <td>0.564</td>\n",
|
||
" <td>29</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>682</th>\n",
|
||
" <td>0</td>\n",
|
||
" <td>95</td>\n",
|
||
" <td>64</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td>105</td>\n",
|
||
" <td>44.6</td>\n",
|
||
" <td>0.366</td>\n",
|
||
" <td>22</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>451</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>134</td>\n",
|
||
" <td>70</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>28.9</td>\n",
|
||
" <td>0.542</td>\n",
|
||
" <td>23</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>113</th>\n",
|
||
" <td>4</td>\n",
|
||
" <td>76</td>\n",
|
||
" <td>62</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>34.0</td>\n",
|
||
" <td>0.391</td>\n",
|
||
" <td>25</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>556</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>97</td>\n",
|
||
" <td>70</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>38.1</td>\n",
|
||
" <td>0.218</td>\n",
|
||
" <td>30</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>667</th>\n",
|
||
" <td>10</td>\n",
|
||
" <td>111</td>\n",
|
||
" <td>70</td>\n",
|
||
" <td>27</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>27.5</td>\n",
|
||
" <td>0.141</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>107</th>\n",
|
||
" <td>4</td>\n",
|
||
" <td>144</td>\n",
|
||
" <td>58</td>\n",
|
||
" <td>28</td>\n",
|
||
" <td>140</td>\n",
|
||
" <td>29.5</td>\n",
|
||
" <td>0.287</td>\n",
|
||
" <td>37</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>614 rows × 9 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
||
"353 1 90 62 12 43 27.2 \n",
|
||
"711 5 126 78 27 22 29.6 \n",
|
||
"373 2 105 58 40 94 34.9 \n",
|
||
"46 1 146 56 0 0 29.7 \n",
|
||
"682 0 95 64 39 105 44.6 \n",
|
||
".. ... ... ... ... ... ... \n",
|
||
"451 2 134 70 0 0 28.9 \n",
|
||
"113 4 76 62 0 0 34.0 \n",
|
||
"556 1 97 70 40 0 38.1 \n",
|
||
"667 10 111 70 27 0 27.5 \n",
|
||
"107 4 144 58 28 140 29.5 \n",
|
||
"\n",
|
||
" DiabetesPedigreeFunction Age Outcome \n",
|
||
"353 0.580 24 0 \n",
|
||
"711 0.439 40 0 \n",
|
||
"373 0.225 25 0 \n",
|
||
"46 0.564 29 0 \n",
|
||
"682 0.366 22 0 \n",
|
||
".. ... ... ... \n",
|
||
"451 0.542 23 1 \n",
|
||
"113 0.391 25 0 \n",
|
||
"556 0.218 30 0 \n",
|
||
"667 0.141 40 1 \n",
|
||
"107 0.287 37 0 \n",
|
||
"\n",
|
||
"[614 rows x 9 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Outcome</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>353</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>711</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>373</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>46</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>682</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>451</th>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>113</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>556</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>667</th>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>107</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>614 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Outcome\n",
|
||
"353 0\n",
|
||
"711 0\n",
|
||
"373 0\n",
|
||
"46 0\n",
|
||
"682 0\n",
|
||
".. ...\n",
|
||
"451 1\n",
|
||
"113 0\n",
|
||
"556 0\n",
|
||
"667 1\n",
|
||
"107 0\n",
|
||
"\n",
|
||
"[614 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Pregnancies</th>\n",
|
||
" <th>Glucose</th>\n",
|
||
" <th>BloodPressure</th>\n",
|
||
" <th>SkinThickness</th>\n",
|
||
" <th>Insulin</th>\n",
|
||
" <th>BMI</th>\n",
|
||
" <th>DiabetesPedigreeFunction</th>\n",
|
||
" <th>Age</th>\n",
|
||
" <th>Outcome</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>44</th>\n",
|
||
" <td>7</td>\n",
|
||
" <td>159</td>\n",
|
||
" <td>64</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>27.4</td>\n",
|
||
" <td>0.294</td>\n",
|
||
" <td>40</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>672</th>\n",
|
||
" <td>10</td>\n",
|
||
" <td>68</td>\n",
|
||
" <td>106</td>\n",
|
||
" <td>23</td>\n",
|
||
" <td>49</td>\n",
|
||
" <td>35.5</td>\n",
|
||
" <td>0.285</td>\n",
|
||
" <td>47</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>700</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>122</td>\n",
|
||
" <td>76</td>\n",
|
||
" <td>27</td>\n",
|
||
" <td>200</td>\n",
|
||
" <td>35.9</td>\n",
|
||
" <td>0.483</td>\n",
|
||
" <td>26</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>630</th>\n",
|
||
" <td>7</td>\n",
|
||
" <td>114</td>\n",
|
||
" <td>64</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>27.4</td>\n",
|
||
" <td>0.732</td>\n",
|
||
" <td>34</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>81</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>74</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.102</td>\n",
|
||
" <td>22</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>32</th>\n",
|
||
" <td>3</td>\n",
|
||
" <td>88</td>\n",
|
||
" <td>58</td>\n",
|
||
" <td>11</td>\n",
|
||
" <td>54</td>\n",
|
||
" <td>24.8</td>\n",
|
||
" <td>0.267</td>\n",
|
||
" <td>22</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>637</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>94</td>\n",
|
||
" <td>76</td>\n",
|
||
" <td>18</td>\n",
|
||
" <td>66</td>\n",
|
||
" <td>31.6</td>\n",
|
||
" <td>0.649</td>\n",
|
||
" <td>23</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>593</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>82</td>\n",
|
||
" <td>52</td>\n",
|
||
" <td>22</td>\n",
|
||
" <td>115</td>\n",
|
||
" <td>28.5</td>\n",
|
||
" <td>1.699</td>\n",
|
||
" <td>25</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>425</th>\n",
|
||
" <td>4</td>\n",
|
||
" <td>184</td>\n",
|
||
" <td>78</td>\n",
|
||
" <td>39</td>\n",
|
||
" <td>277</td>\n",
|
||
" <td>37.0</td>\n",
|
||
" <td>0.264</td>\n",
|
||
" <td>31</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>273</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>71</td>\n",
|
||
" <td>78</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>45</td>\n",
|
||
" <td>33.2</td>\n",
|
||
" <td>0.422</td>\n",
|
||
" <td>21</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>154 rows × 9 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
||
"44 7 159 64 0 0 27.4 \n",
|
||
"672 10 68 106 23 49 35.5 \n",
|
||
"700 2 122 76 27 200 35.9 \n",
|
||
"630 7 114 64 0 0 27.4 \n",
|
||
"81 2 74 0 0 0 0.0 \n",
|
||
".. ... ... ... ... ... ... \n",
|
||
"32 3 88 58 11 54 24.8 \n",
|
||
"637 2 94 76 18 66 31.6 \n",
|
||
"593 2 82 52 22 115 28.5 \n",
|
||
"425 4 184 78 39 277 37.0 \n",
|
||
"273 1 71 78 50 45 33.2 \n",
|
||
"\n",
|
||
" DiabetesPedigreeFunction Age Outcome \n",
|
||
"44 0.294 40 0 \n",
|
||
"672 0.285 47 0 \n",
|
||
"700 0.483 26 0 \n",
|
||
"630 0.732 34 1 \n",
|
||
"81 0.102 22 0 \n",
|
||
".. ... ... ... \n",
|
||
"32 0.267 22 0 \n",
|
||
"637 0.649 23 0 \n",
|
||
"593 1.699 25 0 \n",
|
||
"425 0.264 31 1 \n",
|
||
"273 0.422 21 0 \n",
|
||
"\n",
|
||
"[154 rows x 9 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Outcome</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>44</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>672</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>700</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>630</th>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>81</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>32</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>637</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>593</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>425</th>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>273</th>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>154 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Outcome\n",
|
||
"44 0\n",
|
||
"672 0\n",
|
||
"700 0\n",
|
||
"630 1\n",
|
||
"81 0\n",
|
||
".. ...\n",
|
||
"32 0\n",
|
||
"637 0\n",
|
||
"593 0\n",
|
||
"425 1\n",
|
||
"273 0\n",
|
||
"\n",
|
||
"[154 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from typing import Tuple\n",
|
||
"import pandas as pd\n",
|
||
"from pandas import DataFrame\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"\n",
|
||
"random_state = 42\n",
|
||
"\n",
|
||
"def split_stratified_into_train_val_test(\n",
|
||
" df_input,\n",
|
||
" stratify_colname=\"y\",\n",
|
||
" frac_train=0.6,\n",
|
||
" frac_val=0.15,\n",
|
||
" frac_test=0.25,\n",
|
||
" random_state=None,\n",
|
||
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
||
" \n",
|
||
" if frac_train + frac_val + frac_test != 1.0:\n",
|
||
" raise ValueError(\n",
|
||
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
|
||
" % (frac_train, frac_val, frac_test)\n",
|
||
" )\n",
|
||
" if stratify_colname not in df_input.columns:\n",
|
||
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
|
||
" X = df_input\n",
|
||
" y = df_input[\n",
|
||
" [stratify_colname]\n",
|
||
" ]\n",
|
||
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
|
||
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
|
||
" )\n",
|
||
" if frac_val <= 0:\n",
|
||
" assert len(df_input) == len(df_train) + len(df_temp)\n",
|
||
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
|
||
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
|
||
" df_val, df_test, y_val, y_test = train_test_split(\n",
|
||
" df_temp,\n",
|
||
" y_temp,\n",
|
||
" stratify=y_temp,\n",
|
||
" test_size=relative_frac_test,\n",
|
||
" random_state=random_state,\n",
|
||
" )\n",
|
||
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
|
||
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
|
||
"\n",
|
||
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
|
||
" df, stratify_colname=\"Outcome\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
|
||
")\n",
|
||
"\n",
|
||
"display(\"X_train\", X_train)\n",
|
||
"display(\"y_train\", y_train)\n",
|
||
"\n",
|
||
"display(\"X_test\", X_test)\n",
|
||
"display(\"y_test\", y_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Формирование конвейера для классификации данных\n",
|
||
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
|
||
"\n",
|
||
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
|
||
"\n",
|
||
"features_preprocessing -- трансформер для предобработки признаков\n",
|
||
"\n",
|
||
"features_engineering -- трансформер для конструирования признаков\n",
|
||
"\n",
|
||
"drop_columns -- трансформер для удаления колонок\n",
|
||
"\n",
|
||
"features_postprocessing -- трансформер для унитарного кодирования новых признаков\n",
|
||
"\n",
|
||
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков\n",
|
||
"\n",
|
||
"Конвейер выполняется последовательно.\n",
|
||
"\n",
|
||
"Трансформер выполняет параллельно для указанного набора колонок."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"from sklearn.base import BaseEstimator, TransformerMixin\n",
|
||
"from sklearn.compose import ColumnTransformer\n",
|
||
"from sklearn.impute import SimpleImputer\n",
|
||
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"\n",
|
||
"# Построение конвейеров предобработки\n",
|
||
"class DiabetesFeatures(BaseEstimator, TransformerMixin):\n",
|
||
" def __init__(self):\n",
|
||
" pass\n",
|
||
"\n",
|
||
" def fit(self, X, y=None):\n",
|
||
" return self\n",
|
||
"\n",
|
||
" def transform(self, X, y=None):\n",
|
||
" # Добавим признак отношения индекса массы тела и возраста.\n",
|
||
" X = X.copy()\n",
|
||
" X[\"Glucose_Insulin\"] = X[\"Glucose\"] / X[\"Insulin\"]\n",
|
||
" return X\n",
|
||
"\n",
|
||
" def get_feature_names_out(self, features_in):\n",
|
||
" new_features = [\"Glucose_Insulin\"]\n",
|
||
" return np.append(features_in, new_features, axis=0)\n",
|
||
"\n",
|
||
"# Обработка числовых данных. Числовой конвейр: заполнение пропущенных значений медианой и стандартизация\n",
|
||
"preprocessing_num_class = Pipeline(steps=[\n",
|
||
" ('imputer', SimpleImputer(strategy='median')),\n",
|
||
" ('scaler', StandardScaler())\n",
|
||
"])\n",
|
||
"\n",
|
||
"preprocessing_cat_class = Pipeline(steps=[\n",
|
||
" ('imputer', SimpleImputer(strategy='most_frequent')),\n",
|
||
" ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop='first'))\n",
|
||
"])\n",
|
||
"\n",
|
||
"columns_to_drop = []\n",
|
||
"numeric_columns = [\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\",\n",
|
||
" \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]\n",
|
||
"cat_columns = [\"Outcome\"]\n",
|
||
"\n",
|
||
"features_preprocessing = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=False,\n",
|
||
" transformers=[\n",
|
||
" (\"preprocessing_num\", preprocessing_num_class, numeric_columns),\n",
|
||
" (\"preprocessing_cat\", preprocessing_cat_class, cat_columns),\n",
|
||
" ],\n",
|
||
" remainder=\"passthrough\"\n",
|
||
")\n",
|
||
"\n",
|
||
"drop_columns = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=False,\n",
|
||
" transformers=[\n",
|
||
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
||
" ],\n",
|
||
" remainder=\"passthrough\",\n",
|
||
")\n",
|
||
"\n",
|
||
"features_postprocessing = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=False,\n",
|
||
" transformers=[\n",
|
||
" ('preprocessing_cat', preprocessing_cat_class, [\"Outcome\"]),\n",
|
||
" ],\n",
|
||
" remainder=\"passthrough\",\n",
|
||
")\n",
|
||
"\n",
|
||
"pipeline_end = Pipeline(\n",
|
||
" [\n",
|
||
" (\"features_preprocessing\", features_preprocessing),\n",
|
||
" (\"custom_features\", DiabetesFeatures()),\n",
|
||
" (\"drop_columns\", drop_columns),\n",
|
||
" ]\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Работа конвейера:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Pregnancies</th>\n",
|
||
" <th>Glucose</th>\n",
|
||
" <th>BloodPressure</th>\n",
|
||
" <th>SkinThickness</th>\n",
|
||
" <th>Insulin</th>\n",
|
||
" <th>BMI</th>\n",
|
||
" <th>DiabetesPedigreeFunction</th>\n",
|
||
" <th>Age</th>\n",
|
||
" <th>Outcome_1</th>\n",
|
||
" <th>Glucose_Insulin</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>353</th>\n",
|
||
" <td>-0.851355</td>\n",
|
||
" <td>-0.980131</td>\n",
|
||
" <td>-0.404784</td>\n",
|
||
" <td>-0.553973</td>\n",
|
||
" <td>-0.331319</td>\n",
|
||
" <td>-0.607678</td>\n",
|
||
" <td>0.310794</td>\n",
|
||
" <td>-0.792169</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>2.958266</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>711</th>\n",
|
||
" <td>0.356576</td>\n",
|
||
" <td>0.161444</td>\n",
|
||
" <td>0.465368</td>\n",
|
||
" <td>0.392787</td>\n",
|
||
" <td>-0.526398</td>\n",
|
||
" <td>-0.302139</td>\n",
|
||
" <td>-0.116439</td>\n",
|
||
" <td>0.561034</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-0.306696</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>373</th>\n",
|
||
" <td>-0.549372</td>\n",
|
||
" <td>-0.504474</td>\n",
|
||
" <td>-0.622322</td>\n",
|
||
" <td>1.213312</td>\n",
|
||
" <td>0.142444</td>\n",
|
||
" <td>0.372594</td>\n",
|
||
" <td>-0.764862</td>\n",
|
||
" <td>-0.707594</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-3.541575</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>46</th>\n",
|
||
" <td>-0.851355</td>\n",
|
||
" <td>0.795653</td>\n",
|
||
" <td>-0.731091</td>\n",
|
||
" <td>-1.311380</td>\n",
|
||
" <td>-0.730766</td>\n",
|
||
" <td>-0.289408</td>\n",
|
||
" <td>0.262314</td>\n",
|
||
" <td>-0.369293</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-1.088792</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>682</th>\n",
|
||
" <td>-1.153338</td>\n",
|
||
" <td>-0.821579</td>\n",
|
||
" <td>-0.296015</td>\n",
|
||
" <td>1.150195</td>\n",
|
||
" <td>0.244628</td>\n",
|
||
" <td>1.607482</td>\n",
|
||
" <td>-0.337630</td>\n",
|
||
" <td>-0.961320</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-3.358486</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>451</th>\n",
|
||
" <td>-0.549372</td>\n",
|
||
" <td>0.415128</td>\n",
|
||
" <td>0.030292</td>\n",
|
||
" <td>-1.311380</td>\n",
|
||
" <td>-0.730766</td>\n",
|
||
" <td>-0.391255</td>\n",
|
||
" <td>0.195653</td>\n",
|
||
" <td>-0.876744</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>-0.568071</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>113</th>\n",
|
||
" <td>0.054593</td>\n",
|
||
" <td>-1.424076</td>\n",
|
||
" <td>-0.404784</td>\n",
|
||
" <td>-1.311380</td>\n",
|
||
" <td>-0.730766</td>\n",
|
||
" <td>0.258017</td>\n",
|
||
" <td>-0.261879</td>\n",
|
||
" <td>-0.707594</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.948744</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>556</th>\n",
|
||
" <td>-0.851355</td>\n",
|
||
" <td>-0.758158</td>\n",
|
||
" <td>0.030292</td>\n",
|
||
" <td>1.213312</td>\n",
|
||
" <td>-0.730766</td>\n",
|
||
" <td>0.779980</td>\n",
|
||
" <td>-0.786072</td>\n",
|
||
" <td>-0.284718</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.037483</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>667</th>\n",
|
||
" <td>1.866489</td>\n",
|
||
" <td>-0.314212</td>\n",
|
||
" <td>0.030292</td>\n",
|
||
" <td>0.392787</td>\n",
|
||
" <td>-0.730766</td>\n",
|
||
" <td>-0.569486</td>\n",
|
||
" <td>-1.019383</td>\n",
|
||
" <td>0.561034</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.429976</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>107</th>\n",
|
||
" <td>0.054593</td>\n",
|
||
" <td>0.732232</td>\n",
|
||
" <td>-0.622322</td>\n",
|
||
" <td>0.455904</td>\n",
|
||
" <td>0.569759</td>\n",
|
||
" <td>-0.314870</td>\n",
|
||
" <td>-0.577001</td>\n",
|
||
" <td>0.307308</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.285160</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>614 rows × 10 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
||
"353 -0.851355 -0.980131 -0.404784 -0.553973 -0.331319 -0.607678 \n",
|
||
"711 0.356576 0.161444 0.465368 0.392787 -0.526398 -0.302139 \n",
|
||
"373 -0.549372 -0.504474 -0.622322 1.213312 0.142444 0.372594 \n",
|
||
"46 -0.851355 0.795653 -0.731091 -1.311380 -0.730766 -0.289408 \n",
|
||
"682 -1.153338 -0.821579 -0.296015 1.150195 0.244628 1.607482 \n",
|
||
".. ... ... ... ... ... ... \n",
|
||
"451 -0.549372 0.415128 0.030292 -1.311380 -0.730766 -0.391255 \n",
|
||
"113 0.054593 -1.424076 -0.404784 -1.311380 -0.730766 0.258017 \n",
|
||
"556 -0.851355 -0.758158 0.030292 1.213312 -0.730766 0.779980 \n",
|
||
"667 1.866489 -0.314212 0.030292 0.392787 -0.730766 -0.569486 \n",
|
||
"107 0.054593 0.732232 -0.622322 0.455904 0.569759 -0.314870 \n",
|
||
"\n",
|
||
" DiabetesPedigreeFunction Age Outcome_1 Glucose_Insulin \n",
|
||
"353 0.310794 -0.792169 0.0 2.958266 \n",
|
||
"711 -0.116439 0.561034 0.0 -0.306696 \n",
|
||
"373 -0.764862 -0.707594 0.0 -3.541575 \n",
|
||
"46 0.262314 -0.369293 0.0 -1.088792 \n",
|
||
"682 -0.337630 -0.961320 0.0 -3.358486 \n",
|
||
".. ... ... ... ... \n",
|
||
"451 0.195653 -0.876744 1.0 -0.568071 \n",
|
||
"113 -0.261879 -0.707594 0.0 1.948744 \n",
|
||
"556 -0.786072 -0.284718 0.0 1.037483 \n",
|
||
"667 -1.019383 0.561034 1.0 0.429976 \n",
|
||
"107 -0.577001 0.307308 0.0 1.285160 \n",
|
||
"\n",
|
||
"[614 rows x 10 columns]"
|
||
]
|
||
},
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
||
"preprocessed_df = pd.DataFrame(\n",
|
||
" preprocessing_result,\n",
|
||
" columns=pipeline_end.get_feature_names_out(),\n",
|
||
")\n",
|
||
"preprocessed_df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Формирование набора моделей для классификации\n",
|
||
"logistic -- логистическая регрессия\n",
|
||
"\n",
|
||
"ridge -- логистическая регрессия с L2-регуляризацией (гребневый штраф) и балансировкой весов классов\n",
|
||
"\n",
|
||
"decision_tree -- дерево решений\n",
|
||
"\n",
|
||
"knn -- метод k ближайших соседей\n",
|
||
"\n",
|
||
"naive_bayes -- наивный байесовский классификатор (гауссовский)\n",
|
||
"\n",
|
||
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
|
||
"\n",
|
||
"random_forest -- метод случайного леса (набор деревьев решений)\n",
|
||
"\n",
|
||
"mlp -- многослойный персептрон (нейронная сеть)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
|
||
"\n",
|
||
"random_state = 42\n",
|
||
"\n",
|
||
"class_models = {\n",
|
||
" \"logistic\": {\"model\": linear_model.LogisticRegression(random_state=random_state)},\n",
|
||
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\", random_state=random_state)},\n",
|
||
" \"decision_tree\": {\n",
|
||
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
|
||
" },\n",
|
||
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
|
||
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
|
||
" \"gradient_boosting\": {\n",
|
||
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210, random_state=random_state)\n",
|
||
" },\n",
|
||
" \"random_forest\": {\n",
|
||
" \"model\": ensemble.RandomForestClassifier(\n",
|
||
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
|
||
" )\n",
|
||
" },\n",
|
||
" \"mlp\": {\n",
|
||
" \"model\": neural_network.MLPClassifier(\n",
|
||
" hidden_layer_sizes=(7,),\n",
|
||
" max_iter=500,\n",
|
||
" early_stopping=True,\n",
|
||
" random_state=random_state,\n",
|
||
" )\n",
|
||
" },\n",
|
||
"}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Обучение моделей на обучающем наборе данных и оценка на тестовом"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: logistic\n",
|
||
"Model: ridge\n",
|
||
"Model: decision_tree\n",
|
||
"Model: knn\n",
|
||
"Model: naive_bayes\n",
|
||
"Model: gradient_boosting\n",
|
||
"Model: random_forest\n",
|
||
"Model: mlp\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"\n",
|
||
"for model_name in class_models.keys():\n",
|
||
" print(f\"Model: {model_name}\")\n",
|
||
" model = class_models[model_name][\"model\"]\n",
|
||
"\n",
|
||
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
||
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
||
"\n",
|
||
" y_train_predict = model_pipeline.predict(X_train)\n",
|
||
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
||
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
||
"\n",
|
||
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
||
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
||
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
||
"\n",
|
||
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
|
||
" y_test, y_test_probs\n",
|
||
" )\n",
|
||
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
|
||
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
|
||
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Матрицы ошибок (confusion matrix) на тестовой выборке для использованных моделей классификации"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1200x1000 with 16 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"\n",
|
||
"n_rows = int(len(class_models) / 2)\n",
|
||
"n_cols = 2\n",
|
||
"\n",
|
||
"fig, ax = plt.subplots(n_rows, n_cols, figsize=(12, 10), sharex=False, sharey=False)\n",
|
||
"\n",
|
||
"for index, key in enumerate(class_models.keys()):\n",
|
||
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
|
||
" disp = ConfusionMatrixDisplay(\n",
|
||
" confusion_matrix=c_matrix, display_labels=[\"No Diabetes\", \"Diabetes\"]\n",
|
||
" ).plot(ax=ax.flat[index])\n",
|
||
" disp.ax_.set_title(key)\n",
|
||
"\n",
|
||
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Значение 100 у большинства моделей — это количество верно определённых людей без диабета (True Negatives): модель правильно отнесла их к классу \"No Diabetes\".\n",
|
||
"\n",
|
||
"Значение 54 у большинства моделей — это количество верно определённых людей с диабетом (True Positives): модель правильно отнесла их к классу \"Diabetes\".\n",
|
||
"\n",
|
||
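"Большинство моделей не допускают ошибок на тестовой выборке: все значения вне главной диагонали матрицы равны нулю, то есть модели безошибочно определяют как людей без диабета \"No Diabetes\", так и людей с диабетом \"Diabetes\". Ошибки (False Positives и False Negatives) есть только у KNN и MLP, причём MLP пропускает большинство людей с диабетом (полнота на тестовой выборке около 0.11). Идеальный результат остальных моделей подозрителен: среди признаков после предобработки присутствует столбец Outcome_1, то есть целевая переменная попала во входные данные (утечка целевой переменной). Столбец Outcome следует исключить из X_train и X_test.\n",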
|
||
"\n",
|
||
"Точность (Precision), полнота (Recall), верность (Accuracy), F-мера (F1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_fed29_row0_col0, #T_fed29_row0_col1, #T_fed29_row0_col2, #T_fed29_row0_col3, #T_fed29_row1_col0, #T_fed29_row1_col1, #T_fed29_row1_col2, #T_fed29_row1_col3, #T_fed29_row2_col0, #T_fed29_row2_col1, #T_fed29_row2_col2, #T_fed29_row2_col3, #T_fed29_row3_col0, #T_fed29_row3_col1, #T_fed29_row3_col2, #T_fed29_row3_col3, #T_fed29_row4_col0, #T_fed29_row4_col1, #T_fed29_row4_col2, #T_fed29_row4_col3, #T_fed29_row5_col0, #T_fed29_row5_col1, #T_fed29_row5_col2, #T_fed29_row5_col3 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_fed29_row0_col4, #T_fed29_row0_col5, #T_fed29_row0_col6, #T_fed29_row0_col7, #T_fed29_row1_col4, #T_fed29_row1_col5, #T_fed29_row1_col6, #T_fed29_row1_col7, #T_fed29_row2_col4, #T_fed29_row2_col5, #T_fed29_row2_col6, #T_fed29_row2_col7, #T_fed29_row3_col4, #T_fed29_row3_col5, #T_fed29_row3_col6, #T_fed29_row3_col7, #T_fed29_row4_col4, #T_fed29_row4_col5, #T_fed29_row4_col6, #T_fed29_row4_col7, #T_fed29_row5_col4, #T_fed29_row5_col5, #T_fed29_row5_col6, #T_fed29_row5_col7 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_fed29_row6_col0 {\n",
|
||
" background-color: #86d549;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_fed29_row6_col1 {\n",
|
||
" background-color: #4ac16d;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_fed29_row6_col2 {\n",
|
||
" background-color: #75d054;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_fed29_row6_col3 {\n",
|
||
" background-color: #69cd5b;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_fed29_row6_col4 {\n",
|
||
" background-color: #c43e7f;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_fed29_row6_col5 {\n",
|
||
" background-color: #ae2892;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_fed29_row6_col6 {\n",
|
||
" background-color: #cc4977;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_fed29_row6_col7 {\n",
|
||
" background-color: #c13b82;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_fed29_row7_col0, #T_fed29_row7_col1, #T_fed29_row7_col2, #T_fed29_row7_col3 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_fed29_row7_col4, #T_fed29_row7_col5, #T_fed29_row7_col6, #T_fed29_row7_col7 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_fed29\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_fed29_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
||
" <th id=\"T_fed29_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
||
" <th id=\"T_fed29_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
||
" <th id=\"T_fed29_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
||
" <th id=\"T_fed29_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
||
" <th id=\"T_fed29_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_fed29_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
||
" <th id=\"T_fed29_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_fed29_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
|
||
" <td id=\"T_fed29_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_fed29_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
|
||
" <td id=\"T_fed29_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_fed29_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
|
||
" <td id=\"T_fed29_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_fed29_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
|
||
" <td id=\"T_fed29_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_fed29_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
|
||
" <td id=\"T_fed29_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_fed29_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
|
||
" <td id=\"T_fed29_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row5_col5\" class=\"data row5 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row5_col6\" class=\"data row5 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_fed29_row5_col7\" class=\"data row5 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_fed29_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
|
||
" <td id=\"T_fed29_row6_col0\" class=\"data row6 col0\" >0.947090</td>\n",
|
||
" <td id=\"T_fed29_row6_col1\" class=\"data row6 col1\" >0.796296</td>\n",
|
||
" <td id=\"T_fed29_row6_col2\" class=\"data row6 col2\" >0.836449</td>\n",
|
||
" <td id=\"T_fed29_row6_col3\" class=\"data row6 col3\" >0.796296</td>\n",
|
||
" <td id=\"T_fed29_row6_col4\" class=\"data row6 col4\" >0.926710</td>\n",
|
||
" <td id=\"T_fed29_row6_col5\" class=\"data row6 col5\" >0.857143</td>\n",
|
||
" <td id=\"T_fed29_row6_col6\" class=\"data row6 col6\" >0.888337</td>\n",
|
||
" <td id=\"T_fed29_row6_col7\" class=\"data row6 col7\" >0.796296</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_fed29_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
|
||
" <td id=\"T_fed29_row7_col0\" class=\"data row7 col0\" >0.552632</td>\n",
|
||
" <td id=\"T_fed29_row7_col1\" class=\"data row7 col1\" >0.428571</td>\n",
|
||
" <td id=\"T_fed29_row7_col2\" class=\"data row7 col2\" >0.098131</td>\n",
|
||
" <td id=\"T_fed29_row7_col3\" class=\"data row7 col3\" >0.111111</td>\n",
|
||
" <td id=\"T_fed29_row7_col4\" class=\"data row7 col4\" >0.657980</td>\n",
|
||
" <td id=\"T_fed29_row7_col5\" class=\"data row7 col5\" >0.636364</td>\n",
|
||
" <td id=\"T_fed29_row7_col6\" class=\"data row7 col6\" >0.166667</td>\n",
|
||
" <td id=\"T_fed29_row7_col7\" class=\"data row7 col7\" >0.176471</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x135c39b70b0>"
|
||
]
|
||
},
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
||
" [\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" \"Accuracy_train\",\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_train\",\n",
|
||
" \"F1_test\",\n",
|
||
" ]\n",
|
||
"]\n",
|
||
"class_metrics.sort_values(\n",
|
||
" by=\"Accuracy_test\", ascending=False\n",
|
||
").style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
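"Почти все модели в данной выборке, а именно логистическая регрессия, логистическая регрессия с L2-регуляризацией (ridge), дерево решений, наивный байесовский классификатор, случайный лес и градиентный бустинг, показывают идеальные значения (1.0) по всем метрикам на обучающих и тестовых наборах данных. KNN демонстрирует высокие, но не идеальные значения (Accuracy на тестовой выборке ≈ 0.86). Модель MLP значительно уступает остальным: полнота и F-мера на тестовой выборке около 0.11 и 0.18, а Accuracy ≈ 0.64 не лучше, чем у тривиального классификатора, всегда предсказывающего \"No Diabetes\" (100 из 154 ≈ 0.65). Идеальные значения метрик не отражают реальное качество моделей: среди входных признаков присутствует Outcome_1, то есть целевая переменная попала в данные для обучения."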
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_2bffc_row0_col0, #T_2bffc_row0_col1, #T_2bffc_row1_col0, #T_2bffc_row1_col1, #T_2bffc_row2_col0, #T_2bffc_row2_col1, #T_2bffc_row3_col0, #T_2bffc_row3_col1, #T_2bffc_row4_col0, #T_2bffc_row4_col1, #T_2bffc_row5_col0, #T_2bffc_row5_col1 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_2bffc_row0_col2, #T_2bffc_row0_col3, #T_2bffc_row0_col4, #T_2bffc_row1_col2, #T_2bffc_row1_col3, #T_2bffc_row1_col4, #T_2bffc_row2_col2, #T_2bffc_row2_col3, #T_2bffc_row2_col4, #T_2bffc_row3_col2, #T_2bffc_row3_col3, #T_2bffc_row3_col4, #T_2bffc_row4_col2, #T_2bffc_row4_col3, #T_2bffc_row4_col4, #T_2bffc_row5_col2, #T_2bffc_row5_col3, #T_2bffc_row5_col4 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_2bffc_row6_col0 {\n",
|
||
" background-color: #42be71;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_2bffc_row6_col1 {\n",
|
||
" background-color: #65cb5e;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_2bffc_row6_col2 {\n",
|
||
" background-color: #c5407e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_2bffc_row6_col3 {\n",
|
||
" background-color: #b7318a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_2bffc_row6_col4 {\n",
|
||
" background-color: #b6308b;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_2bffc_row7_col0, #T_2bffc_row7_col1 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_2bffc_row7_col2, #T_2bffc_row7_col3, #T_2bffc_row7_col4 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_2bffc\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_2bffc_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_2bffc_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
||
" <th id=\"T_2bffc_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
||
" <th id=\"T_2bffc_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
||
" <th id=\"T_2bffc_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_2bffc_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
|
||
" <td id=\"T_2bffc_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_2bffc_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
|
||
" <td id=\"T_2bffc_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_2bffc_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
|
||
" <td id=\"T_2bffc_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_2bffc_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
|
||
" <td id=\"T_2bffc_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_2bffc_level0_row4\" class=\"row_heading level0 row4\" >gradient_boosting</th>\n",
|
||
" <td id=\"T_2bffc_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_2bffc_level0_row5\" class=\"row_heading level0 row5\" >random_forest</th>\n",
|
||
" <td id=\"T_2bffc_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_2bffc_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_2bffc_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
|
||
" <td id=\"T_2bffc_row6_col0\" class=\"data row6 col0\" >0.857143</td>\n",
|
||
" <td id=\"T_2bffc_row6_col1\" class=\"data row6 col1\" >0.796296</td>\n",
|
||
" <td id=\"T_2bffc_row6_col2\" class=\"data row6 col2\" >0.923333</td>\n",
|
||
" <td id=\"T_2bffc_row6_col3\" class=\"data row6 col3\" >0.686296</td>\n",
|
||
" <td id=\"T_2bffc_row6_col4\" class=\"data row6 col4\" >0.686296</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_2bffc_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
|
||
" <td id=\"T_2bffc_row7_col0\" class=\"data row7 col0\" >0.636364</td>\n",
|
||
" <td id=\"T_2bffc_row7_col1\" class=\"data row7 col1\" >0.176471</td>\n",
|
||
" <td id=\"T_2bffc_row7_col2\" class=\"data row7 col2\" >0.630556</td>\n",
|
||
" <td id=\"T_2bffc_row7_col3\" class=\"data row7 col3\" >0.037500</td>\n",
|
||
" <td id=\"T_2bffc_row7_col4\" class=\"data row7 col4\" >0.051640</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x135c3bf80b0>"
|
||
]
|
||
},
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
||
" [\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" ]\n",
|
||
"]\n",
|
||
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" ],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
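"Аналогичный вывод следует и из метрик Accuracy, F1, ROC AUC, Cohen's Kappa и MCC: все модели, кроме KNN и MLP, разделяют классы без ошибок (вероятно, из-за утечки целевой переменной — признака Outcome_1). KNN показывает хорошее качество (ROC AUC ≈ 0.92, Cohen's Kappa и MCC ≈ 0.69). У MLP значения Cohen's Kappa и MCC близки к нулю (≈ 0.04 и 0.05), то есть её предсказания практически не лучше случайных."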
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'logistic'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
|
||
"\n",
|
||
"display(best_model)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Вывод данных с ошибкой предсказания для оценки"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'Error items count: 0'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Pregnancies</th>\n",
|
||
" <th>Predicted</th>\n",
|
||
" <th>Glucose</th>\n",
|
||
" <th>BloodPressure</th>\n",
|
||
" <th>SkinThickness</th>\n",
|
||
" <th>Insulin</th>\n",
|
||
" <th>BMI</th>\n",
|
||
" <th>DiabetesPedigreeFunction</th>\n",
|
||
" <th>Age</th>\n",
|
||
" <th>Outcome</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
"Empty DataFrame\n",
|
||
"Columns: [Pregnancies, Predicted, Glucose, BloodPressure, SkinThickness, Insulin, BMI, DiabetesPedigreeFunction, Age, Outcome]\n",
|
||
"Index: []"
|
||
]
|
||
},
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"preprocessing_result = pipeline_end.transform(X_test)\n",
|
||
"preprocessed_df = pd.DataFrame(\n",
|
||
" preprocessing_result,\n",
|
||
" columns=pipeline_end.get_feature_names_out(),\n",
|
||
")\n",
|
||
"\n",
|
||
"y_pred = class_models[best_model][\"preds\"]\n",
|
||
"\n",
|
||
"error_index = y_test[y_test[\"Outcome\"] != y_pred].index.tolist()\n",
|
||
"display(f\"Error items count: {len(error_index)}\")\n",
|
||
"\n",
|
||
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
|
||
"error_df = X_test.loc[error_index].copy()\n",
|
||
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
|
||
"error_df.sort_index()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Пример использования обученной модели (конвейера) для предсказания"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Pregnancies</th>\n",
|
||
" <th>Glucose</th>\n",
|
||
" <th>BloodPressure</th>\n",
|
||
" <th>SkinThickness</th>\n",
|
||
" <th>Insulin</th>\n",
|
||
" <th>BMI</th>\n",
|
||
" <th>DiabetesPedigreeFunction</th>\n",
|
||
" <th>Age</th>\n",
|
||
" <th>Outcome</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>190</th>\n",
|
||
" <td>3.0</td>\n",
|
||
" <td>111.0</td>\n",
|
||
" <td>62.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>22.6</td>\n",
|
||
" <td>0.142</td>\n",
|
||
" <td>21.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
||
"190 3.0 111.0 62.0 0.0 0.0 22.6 \n",
|
||
"\n",
|
||
" DiabetesPedigreeFunction Age Outcome \n",
|
||
"190 0.142 21.0 0.0 "
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Pregnancies</th>\n",
|
||
" <th>Glucose</th>\n",
|
||
" <th>BloodPressure</th>\n",
|
||
" <th>SkinThickness</th>\n",
|
||
" <th>Insulin</th>\n",
|
||
" <th>BMI</th>\n",
|
||
" <th>DiabetesPedigreeFunction</th>\n",
|
||
" <th>Age</th>\n",
|
||
" <th>Outcome_1</th>\n",
|
||
" <th>Glucose_Insulin</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>190</th>\n",
|
||
" <td>-0.24739</td>\n",
|
||
" <td>-0.314212</td>\n",
|
||
" <td>-0.404784</td>\n",
|
||
" <td>-1.31138</td>\n",
|
||
" <td>-0.730766</td>\n",
|
||
" <td>-1.193296</td>\n",
|
||
" <td>-1.016353</td>\n",
|
||
" <td>-1.045895</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.429976</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
||
"190 -0.24739 -0.314212 -0.404784 -1.31138 -0.730766 -1.193296 \n",
|
||
"\n",
|
||
" DiabetesPedigreeFunction Age Outcome_1 Glucose_Insulin \n",
|
||
"190 -1.016353 -1.045895 0.0 0.429976 "
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'predicted: 0 (proba: [0.99177252 0.00822748])'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'real: 0'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"model = class_models[best_model][\"pipeline\"]\n",
|
||
"\n",
|
||
"example_id = 190\n",
|
||
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
|
||
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
|
||
"display(test)\n",
|
||
"display(test_preprocessed)\n",
|
||
"result_proba = model.predict_proba(test)[0]\n",
|
||
"result = model.predict(test)[0]\n",
|
||
"real = int(y_test.loc[example_id].values[0])\n",
|
||
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
|
||
"display(f\"real: {real}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Подбор гиперпараметров методом поиска по сетке"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\tabee\\AIM_PIbd-31_Tabeev_A.P\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
||
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'model__criterion': 'gini',\n",
|
||
" 'model__max_depth': 5,\n",
|
||
" 'model__max_features': 'sqrt',\n",
|
||
" 'model__n_estimators': 10}"
|
||
]
|
||
},
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.model_selection import GridSearchCV\n",
|
||
"\n",
|
||
"optimized_model_type = \"random_forest\"\n",
|
||
"\n",
|
||
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
|
||
"\n",
|
||
"param_grid = {\n",
|
||
" \"model__n_estimators\": [10, 50, 100],\n",
|
||
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
|
||
" \"model__max_depth\": [5, 7, 10],\n",
|
||
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
||
"}\n",
|
||
"\n",
|
||
"gs_optomizer = GridSearchCV(\n",
|
||
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
|
||
")\n",
|
||
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
|
||
"gs_optomizer.best_params_"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Обучение модели с новыми гиперпараметрами"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"from sklearn.compose import ColumnTransformer\n",
|
||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
|
||
"\n",
|
||
"random_state = 42\n",
|
||
"# Определение трансформера\n",
|
||
"pipeline_end = ColumnTransformer([\n",
|
||
" ('numeric', StandardScaler(), numeric_features),\n",
|
||
"])\n",
|
||
"\n",
|
||
"optimized_model = RandomForestClassifier(\n",
|
||
" random_state=random_state,\n",
|
||
" criterion=\"gini\",\n",
|
||
" max_depth=5,\n",
|
||
" max_features=\"sqrt\",\n",
|
||
" n_estimators=50,\n",
|
||
")\n",
|
||
"\n",
|
||
"result = {}\n",
|
||
"\n",
|
||
"# Обучение модели\n",
|
||
"result[\"pipeline\"] = Pipeline([\n",
|
||
" (\"pipeline\", pipeline_end),\n",
|
||
" (\"model\", optimized_model)\n",
|
||
"]).fit(X_train, y_train.values.ravel())\n",
|
||
"\n",
|
||
"# Прогнозирование и расчет метрик\n",
|
||
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
|
||
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
|
||
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
|
||
"\n",
|
||
"# Метрики для оценки модели\n",
|
||
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
|
||
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
|
||
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
|
||
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
|
||
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
|
||
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
|
||
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
|
||
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Формирование данных для оценки старой и новой версии модели"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
|
||
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
||
" data=class_models[optimized_model_type]\n",
|
||
")\n",
|
||
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
||
" data=result\n",
|
||
")\n",
|
||
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
|
||
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Оценка метрик старой и новой версий модели"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_02d85_row0_col0, #T_02d85_row0_col1, #T_02d85_row0_col2, #T_02d85_row0_col3, #T_02d85_row1_col0, #T_02d85_row1_col1, #T_02d85_row1_col2, #T_02d85_row1_col3 {\n",
|
||
" background-color: #440154;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_02d85_row0_col4, #T_02d85_row0_col5, #T_02d85_row0_col6, #T_02d85_row0_col7, #T_02d85_row1_col4, #T_02d85_row1_col5, #T_02d85_row1_col6, #T_02d85_row1_col7 {\n",
|
||
" background-color: #0d0887;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_02d85\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_02d85_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
||
" <th id=\"T_02d85_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
||
" <th id=\"T_02d85_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
||
" <th id=\"T_02d85_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
||
" <th id=\"T_02d85_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
||
" <th id=\"T_02d85_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_02d85_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
||
" <th id=\"T_02d85_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th class=\"index_name level0\" >Name</th>\n",
|
||
" <th class=\"blank col0\" > </th>\n",
|
||
" <th class=\"blank col1\" > </th>\n",
|
||
" <th class=\"blank col2\" > </th>\n",
|
||
" <th class=\"blank col3\" > </th>\n",
|
||
" <th class=\"blank col4\" > </th>\n",
|
||
" <th class=\"blank col5\" > </th>\n",
|
||
" <th class=\"blank col6\" > </th>\n",
|
||
" <th class=\"blank col7\" > </th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_02d85_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
||
" <td id=\"T_02d85_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_02d85_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
||
" <td id=\"T_02d85_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_02d85_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x135c73282c0>"
|
||
]
|
||
},
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"optimized_metrics[\n",
|
||
" [\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" \"Accuracy_train\",\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_train\",\n",
|
||
" \"F1_test\",\n",
|
||
" ]\n",
|
||
"].style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 29,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_34fcd_row0_col0, #T_34fcd_row0_col1, #T_34fcd_row1_col0, #T_34fcd_row1_col1 {\n",
|
||
" background-color: #440154;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_34fcd_row0_col2, #T_34fcd_row0_col3, #T_34fcd_row0_col4, #T_34fcd_row1_col3, #T_34fcd_row1_col4 {\n",
|
||
" background-color: #0d0887;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_34fcd_row1_col2 {\n",
|
||
" background-color: #f0f921;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_34fcd\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_34fcd_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_34fcd_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
||
" <th id=\"T_34fcd_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
||
" <th id=\"T_34fcd_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
||
" <th id=\"T_34fcd_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th class=\"index_name level0\" >Name</th>\n",
|
||
" <th class=\"blank col0\" > </th>\n",
|
||
" <th class=\"blank col1\" > </th>\n",
|
||
" <th class=\"blank col2\" > </th>\n",
|
||
" <th class=\"blank col3\" > </th>\n",
|
||
" <th class=\"blank col4\" > </th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_34fcd_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
||
" <td id=\"T_34fcd_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_34fcd_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_34fcd_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_34fcd_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_34fcd_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_34fcd_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
||
" <td id=\"T_34fcd_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_34fcd_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_34fcd_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_34fcd_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_34fcd_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x135c7328590>"
|
||
]
|
||
},
|
||
"execution_count": 29,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"optimized_metrics[\n",
|
||
" [\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" ]\n",
|
||
"].style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" ],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1000x400 with 4 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
|
||
")\n",
|
||
"\n",
|
||
"for index in range(0, len(optimized_metrics)):\n",
|
||
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
|
||
" disp = ConfusionMatrixDisplay(\n",
|
||
" confusion_matrix=c_matrix, display_labels=[\"No Diabetes\", \"Diabetes\"]\n",
|
||
" ).plot(ax=ax.flat[index])\n",
|
||
"\n",
|
||
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Значение 100 означает количество верно классифицированных объектов, которые относятся к классу \"No Diabetes\". Можно сделать вывод, что модель отлично идентифицирует объекты этого класса.\n",
|
||
"\n",
|
||
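"Значения 0 — это количество ошибочно классифицированных объектов (ложноположительных и ложноотрицательных срабатываний). Поскольку Precision, Recall и Accuracy на тестовой выборке равны 1.0 у обеих версий модели, модель не допускает ни одной ошибки, в том числе для класса \"Diabetes\". Такой идеальный результат подозрителен: среди признаков X_test присутствует целевой столбец \"Outcome\", то есть модель фактически видит правильный ответ (утечка целевой переменной)."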
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"(700, 8)\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Pregnancies</th>\n",
|
||
" <th>Glucose</th>\n",
|
||
" <th>BloodPressure</th>\n",
|
||
" <th>SkinThickness</th>\n",
|
||
" <th>Insulin</th>\n",
|
||
" <th>BMI</th>\n",
|
||
" <th>DiabetesPedigreeFunction</th>\n",
|
||
" <th>Age</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>6</td>\n",
|
||
" <td>98</td>\n",
|
||
" <td>58</td>\n",
|
||
" <td>33</td>\n",
|
||
" <td>190</td>\n",
|
||
" <td>34.0</td>\n",
|
||
" <td>0.430</td>\n",
|
||
" <td>43</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>112</td>\n",
|
||
" <td>75</td>\n",
|
||
" <td>32</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>35.7</td>\n",
|
||
" <td>0.148</td>\n",
|
||
" <td>21</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>108</td>\n",
|
||
" <td>64</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>30.8</td>\n",
|
||
" <td>0.158</td>\n",
|
||
" <td>21</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>8</td>\n",
|
||
" <td>107</td>\n",
|
||
" <td>80</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>24.6</td>\n",
|
||
" <td>0.856</td>\n",
|
||
" <td>34</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>7</td>\n",
|
||
" <td>136</td>\n",
|
||
" <td>90</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>29.9</td>\n",
|
||
" <td>0.210</td>\n",
|
||
" <td>50</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>695</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>105</td>\n",
|
||
" <td>80</td>\n",
|
||
" <td>45</td>\n",
|
||
" <td>191</td>\n",
|
||
" <td>33.7</td>\n",
|
||
" <td>0.711</td>\n",
|
||
" <td>29</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>696</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>126</td>\n",
|
||
" <td>56</td>\n",
|
||
" <td>29</td>\n",
|
||
" <td>152</td>\n",
|
||
" <td>28.7</td>\n",
|
||
" <td>0.801</td>\n",
|
||
" <td>21</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>697</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>95</td>\n",
|
||
" <td>54</td>\n",
|
||
" <td>14</td>\n",
|
||
" <td>88</td>\n",
|
||
" <td>26.1</td>\n",
|
||
" <td>0.748</td>\n",
|
||
" <td>22</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>698</th>\n",
|
||
" <td>3</td>\n",
|
||
" <td>100</td>\n",
|
||
" <td>68</td>\n",
|
||
" <td>23</td>\n",
|
||
" <td>81</td>\n",
|
||
" <td>31.6</td>\n",
|
||
" <td>0.949</td>\n",
|
||
" <td>28</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>699</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>85</td>\n",
|
||
" <td>66</td>\n",
|
||
" <td>29</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>26.6</td>\n",
|
||
" <td>0.351</td>\n",
|
||
" <td>31</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>700 rows × 8 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
||
"0 6 98 58 33 190 34.0 \n",
|
||
"1 2 112 75 32 0 35.7 \n",
|
||
"2 2 108 64 0 0 30.8 \n",
|
||
"3 8 107 80 0 0 24.6 \n",
|
||
"4 7 136 90 0 0 29.9 \n",
|
||
".. ... ... ... ... ... ... \n",
|
||
"695 2 105 80 45 191 33.7 \n",
|
||
"696 1 126 56 29 152 28.7 \n",
|
||
"697 2 95 54 14 88 26.1 \n",
|
||
"698 3 100 68 23 81 31.6 \n",
|
||
"699 1 85 66 29 0 26.6 \n",
|
||
"\n",
|
||
" DiabetesPedigreeFunction Age \n",
|
||
"0 0.430 43 \n",
|
||
"1 0.148 21 \n",
|
||
"2 0.158 21 \n",
|
||
"3 0.856 34 \n",
|
||
"4 0.210 50 \n",
|
||
".. ... ... \n",
|
||
"695 0.711 29 \n",
|
||
"696 0.801 21 \n",
|
||
"697 0.748 22 \n",
|
||
"698 0.949 28 \n",
|
||
"699 0.351 31 \n",
|
||
"\n",
|
||
"[700 rows x 8 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import seaborn as sns\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"from sklearn import set_config\n",
|
||
"\n",
|
||
"random_state = 42\n",
|
||
"set_config(transform_output=\"pandas\")\n",
|
||
"\n",
|
||
"df = pd.read_csv(\"data/diabetes.csv\")\n",
|
||
"\n",
|
||
"df = df.drop(columns=[\"Outcome\"])\n",
|
||
"\n",
|
||
"df = df.sample(n=700, random_state=random_state).reset_index(drop=True)\n",
|
||
"\n",
|
||
"print(df.shape) \n",
|
||
"display(df)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
||
"0 6 148 72 35 0 33.6 \n",
|
||
"1 1 85 66 29 0 26.6 \n",
|
||
"2 8 183 64 0 0 23.3 \n",
|
||
"3 1 89 66 23 94 28.1 \n",
|
||
"4 0 137 40 35 168 43.1 \n",
|
||
"\n",
|
||
" DiabetesPedigreeFunction Age diabetes_risk_index \n",
|
||
"0 0.627 50 71.68 \n",
|
||
"1 0.351 31 46.28 \n",
|
||
"2 0.672 32 74.69 \n",
|
||
"3 0.167 21 55.33 \n",
|
||
"4 2.288 33 81.43 \n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"df = pd.read_csv(\"data/diabetes.csv\")\n",
|
||
"\n",
|
||
"required_columns = [\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]\n",
|
||
"missing_columns = [col for col in required_columns if col not in df.columns]\n",
|
||
"if missing_columns:\n",
|
||
" raise ValueError(f\"Отсутствуют столбцы: {missing_columns}\")\n",
|
||
"\n",
|
||
"df[\"diabetes_risk_index\"] = (\n",
|
||
" df[\"Glucose\"] * 0.3 \n",
|
||
" + df[\"BMI\"] * 0.3 \n",
|
||
" + df[\"Age\"] * 0.2 \n",
|
||
" + df[\"BloodPressure\"] * 0.1 \n",
|
||
" + df[\"Insulin\"] * 0.1 \n",
|
||
")\n",
|
||
"\n",
|
||
"print(df[[\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \"BMI\", \"DiabetesPedigreeFunction\", \"Age\", \"diabetes_risk_index\"]].head())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Разделение набора данных на обучающую и тестовую выборки (80/20) для задачи регрессии"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Pregnancies</th>\n",
|
||
" <th>Glucose</th>\n",
|
||
" <th>BloodPressure</th>\n",
|
||
" <th>SkinThickness</th>\n",
|
||
" <th>Insulin</th>\n",
|
||
" <th>BMI</th>\n",
|
||
" <th>DiabetesPedigreeFunction</th>\n",
|
||
" <th>Age</th>\n",
|
||
" <th>Outcome</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>60</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>84</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.304</td>\n",
|
||
" <td>21</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>618</th>\n",
|
||
" <td>9</td>\n",
|
||
" <td>112</td>\n",
|
||
" <td>82</td>\n",
|
||
" <td>24</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>28.2</td>\n",
|
||
" <td>1.282</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>346</th>\n",
|
||
" <td>1</td>\n",
|
||
" <td>139</td>\n",
|
||
" <td>46</td>\n",
|
||
" <td>19</td>\n",
|
||
" <td>83</td>\n",
|
||
" <td>28.7</td>\n",
|
||
" <td>0.654</td>\n",
|
||
" <td>22</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>294</th>\n",
|
||
" <td>0</td>\n",
|
||
" <td>161</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>21.9</td>\n",
|
||
" <td>0.254</td>\n",
|
||
" <td>65</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>231</th>\n",
|
||
" <td>6</td>\n",
|
||
" <td>134</td>\n",
|
||
" <td>80</td>\n",
|
||
" <td>37</td>\n",
|
||
" <td>370</td>\n",
|
||
" <td>46.2</td>\n",
|
||
" <td>0.238</td>\n",
|
||
" <td>46</td>\n",
|
||
" <td>1</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
||
"60 2 84 0 0 0 0.0 \n",
|
||
"618 9 112 82 24 0 28.2 \n",
|
||
"346 1 139 46 19 83 28.7 \n",
|
||
"294 0 161 50 0 0 21.9 \n",
|
||
"231 6 134 80 37 370 46.2 \n",
|
||
"\n",
|
||
" DiabetesPedigreeFunction Age Outcome \n",
|
||
"60 0.304 21 0 \n",
|
||
"618 1.282 50 1 \n",
|
||
"346 0.654 22 0 \n",
|
||
"294 0.254 65 0 \n",
|
||
"231 0.238 46 1 "
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>diabetes_risk_index</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>60</th>\n",
|
||
" <td>29.40</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>618</th>\n",
|
||
" <td>60.26</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>346</th>\n",
|
||
" <td>67.61</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>294</th>\n",
|
||
" <td>72.87</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>231</th>\n",
|
||
" <td>108.26</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" diabetes_risk_index\n",
|
||
"60 29.40\n",
|
||
"618 60.26\n",
|
||
"346 67.61\n",
|
||
"294 72.87\n",
|
||
"231 108.26"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Pregnancies</th>\n",
|
||
" <th>Glucose</th>\n",
|
||
" <th>BloodPressure</th>\n",
|
||
" <th>SkinThickness</th>\n",
|
||
" <th>Insulin</th>\n",
|
||
" <th>BMI</th>\n",
|
||
" <th>DiabetesPedigreeFunction</th>\n",
|
||
" <th>Age</th>\n",
|
||
" <th>Outcome</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>668</th>\n",
|
||
" <td>6</td>\n",
|
||
" <td>98</td>\n",
|
||
" <td>58</td>\n",
|
||
" <td>33</td>\n",
|
||
" <td>190</td>\n",
|
||
" <td>34.0</td>\n",
|
||
" <td>0.430</td>\n",
|
||
" <td>43</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>324</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>112</td>\n",
|
||
" <td>75</td>\n",
|
||
" <td>32</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>35.7</td>\n",
|
||
" <td>0.148</td>\n",
|
||
" <td>21</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>624</th>\n",
|
||
" <td>2</td>\n",
|
||
" <td>108</td>\n",
|
||
" <td>64</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>30.8</td>\n",
|
||
" <td>0.158</td>\n",
|
||
" <td>21</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>690</th>\n",
|
||
" <td>8</td>\n",
|
||
" <td>107</td>\n",
|
||
" <td>80</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>24.6</td>\n",
|
||
" <td>0.856</td>\n",
|
||
" <td>34</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>473</th>\n",
|
||
" <td>7</td>\n",
|
||
" <td>136</td>\n",
|
||
" <td>90</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>29.9</td>\n",
|
||
" <td>0.210</td>\n",
|
||
" <td>50</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
||
"668 6 98 58 33 190 34.0 \n",
|
||
"324 2 112 75 32 0 35.7 \n",
|
||
"624 2 108 64 0 0 30.8 \n",
|
||
"690 8 107 80 0 0 24.6 \n",
|
||
"473 7 136 90 0 0 29.9 \n",
|
||
"\n",
|
||
" DiabetesPedigreeFunction Age Outcome \n",
|
||
"668 0.430 43 0 \n",
|
||
"324 0.148 21 0 \n",
|
||
"624 0.158 21 0 \n",
|
||
"690 0.856 34 0 \n",
|
||
"473 0.210 50 0 "
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>diabetes_risk_index</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>668</th>\n",
|
||
" <td>73.00</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>324</th>\n",
|
||
" <td>56.01</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>624</th>\n",
|
||
" <td>52.24</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>690</th>\n",
|
||
" <td>54.28</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>473</th>\n",
|
||
" <td>68.77</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" diabetes_risk_index\n",
|
||
"668 73.00\n",
|
||
"324 56.01\n",
|
||
"624 52.24\n",
|
||
"690 54.28\n",
|
||
"473 68.77"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from typing import Tuple\n",
|
||
"import pandas as pd\n",
|
||
"from pandas import DataFrame\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"\n",
|
||
"def split_into_train_test(\n",
|
||
" df_input: DataFrame,\n",
|
||
" target_colname: str = \"diabetes_risk_index\",\n",
|
||
" frac_train: float = 0.8,\n",
|
||
" random_state: int = None,\n",
|
||
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
||
" if not (0 < frac_train < 1):\n",
|
||
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
|
||
" \n",
|
||
" if target_colname not in df_input.columns:\n",
|
||
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
|
||
" \n",
|
||
" X = df_input.drop(columns=[target_colname])\n",
|
||
" y = df_input[[target_colname]]\n",
|
||
"\n",
|
||
" X_train, X_test, y_train, y_test = train_test_split(\n",
|
||
" X, y,\n",
|
||
" test_size=(1.0 - frac_train),\n",
|
||
" random_state=random_state\n",
|
||
" )\n",
|
||
" \n",
|
||
" return X_train, X_test, y_train, y_test\n",
|
||
"\n",
|
||
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
|
||
" df, \n",
|
||
" target_colname=\"diabetes_risk_index\", \n",
|
||
" frac_train=0.8, \n",
|
||
" random_state=42\n",
|
||
")\n",
|
||
"\n",
|
||
"display(\"X_train\", X_train.head())\n",
|
||
"display(\"y_train\", y_train.head())\n",
|
||
"\n",
|
||
"display(\"X_test\", X_test.head())\n",
|
||
"display(\"y_test\", y_test.head())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 34,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.pipeline import make_pipeline\n",
|
||
"from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
|
||
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
|
||
"\n",
|
||
"random_state = 42\n",
|
||
"\n",
|
||
"models = {\n",
|
||
" # Линейная регрессия\n",
|
||
" \"linear\": {\n",
|
||
" \"model\": linear_model.LinearRegression(n_jobs=-1)\n",
|
||
" },\n",
|
||
" # Полиномиальная регрессия степени 2\n",
|
||
" \"linear_poly\": {\n",
|
||
" \"model\": make_pipeline(\n",
|
||
" PolynomialFeatures(degree=2),\n",
|
||
" StandardScaler(),\n",
|
||
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
||
" )\n",
|
||
" },\n",
|
||
" # Полиномиальная регрессия с взаимодействиями\n",
|
||
" \"linear_interact\": {\n",
|
||
" \"model\": make_pipeline(\n",
|
||
" PolynomialFeatures(interaction_only=True),\n",
|
||
" StandardScaler(),\n",
|
||
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
||
" )\n",
|
||
" },\n",
|
||
" # Ridge-регрессия\n",
|
||
" \"ridge\": {\n",
|
||
" \"model\": make_pipeline(\n",
|
||
" StandardScaler(),\n",
|
||
" linear_model.RidgeCV()\n",
|
||
" )\n",
|
||
" },\n",
|
||
" # Регрессия на основе дерева решений\n",
|
||
" \"decision_tree\": {\n",
|
||
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
|
||
" },\n",
|
||
" # Метод ближайших соседей (kNN)\n",
|
||
" \"knn\": {\n",
|
||
" \"model\": make_pipeline(\n",
|
||
" StandardScaler(),\n",
|
||
" neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)\n",
|
||
" )\n",
|
||
" },\n",
|
||
" # Случайный лес (Random Forest)\n",
|
||
" \"random_forest\": {\n",
|
||
" \"model\": ensemble.RandomForestRegressor(\n",
|
||
" max_depth=7, random_state=random_state, n_jobs=-1\n",
|
||
" )\n",
|
||
" },\n",
|
||
" # Нейронная сеть (MLPRegressor)\n",
|
||
" \"mlp\": {\n",
|
||
" \"model\": make_pipeline(\n",
|
||
" StandardScaler(),\n",
|
||
" neural_network.MLPRegressor(\n",
|
||
" activation=\"tanh\",\n",
|
||
" hidden_layer_sizes=(3,),\n",
|
||
" max_iter=500,\n",
|
||
" early_stopping=True,\n",
|
||
" random_state=random_state,\n",
|
||
" )\n",
|
||
" )\n",
|
||
" },\n",
|
||
"}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Обучение и оценка моделей с помощью различных алгоритмов (метрики: RMSE на обучающей и тестовой выборках, RMAE — корень из MAE — и R² на тестовой выборке)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 35,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: linear\n",
|
||
"Model: linear_poly\n",
|
||
"Model: linear_interact\n",
|
||
"Model: ridge\n",
|
||
"Model: decision_tree\n",
|
||
"Model: knn\n",
|
||
"Model: random_forest\n",
|
||
"Model: mlp\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\tabee\\AIM_PIbd-31_Tabeev_A.P\\.venv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
||
" warnings.warn(\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import math\n",
|
||
"from pandas import DataFrame\n",
|
||
"from sklearn import metrics\n",
|
||
"\n",
|
||
"for model_name in models.keys():\n",
|
||
" print(f\"Model: {model_name}\")\n",
|
||
"\n",
|
||
" fitted_model = models[model_name][\"model\"].fit(\n",
|
||
" X_train.values, y_train.values.ravel()\n",
|
||
" )\n",
|
||
" y_train_pred = fitted_model.predict(X_train.values)\n",
|
||
" y_test_pred = fitted_model.predict(X_test.values)\n",
|
||
" models[model_name][\"fitted\"] = fitted_model\n",
|
||
" models[model_name][\"train_preds\"] = y_train_pred\n",
|
||
" models[model_name][\"preds\"] = y_test_pred\n",
|
||
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
|
||
" metrics.mean_squared_error(y_train, y_train_pred)\n",
|
||
" )\n",
|
||
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
|
||
" metrics.mean_squared_error(y_test, y_test_pred)\n",
|
||
" )\n",
|
||
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
|
||
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
|
||
" )\n",
|
||
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Вывод результатов оценки (модели отсортированы по возрастанию RMSE на тестовой выборке)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 36,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_4f71e_row0_col0, #T_4f71e_row0_col1, #T_4f71e_row1_col0, #T_4f71e_row1_col1 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row0_col2, #T_4f71e_row1_col2, #T_4f71e_row7_col3 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row0_col3, #T_4f71e_row1_col3, #T_4f71e_row2_col3, #T_4f71e_row3_col3, #T_4f71e_row7_col2 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row2_col0 {\n",
|
||
" background-color: #25838e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row2_col1 {\n",
|
||
" background-color: #25858e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row2_col2 {\n",
|
||
" background-color: #7100a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row3_col0 {\n",
|
||
" background-color: #25848e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row3_col1 {\n",
|
||
" background-color: #24878e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row3_col2 {\n",
|
||
" background-color: #7501a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row4_col0 {\n",
|
||
" background-color: #23888e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row4_col1 {\n",
|
||
" background-color: #23898e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row4_col2 {\n",
|
||
" background-color: #7a02a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row4_col3, #T_4f71e_row6_col2 {\n",
|
||
" background-color: #d9586a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row5_col0, #T_4f71e_row5_col1 {\n",
|
||
" background-color: #89d548;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_4f71e_row5_col2 {\n",
|
||
" background-color: #d24f71;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row5_col3 {\n",
|
||
" background-color: #7201a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_4f71e_row6_col0, #T_4f71e_row7_col0, #T_4f71e_row7_col1 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_4f71e_row6_col1 {\n",
|
||
" background-color: #a2da37;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_4f71e_row6_col3 {\n",
|
||
" background-color: #5302a3;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_4f71e\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_4f71e_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
|
||
" <th id=\"T_4f71e_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
|
||
" <th id=\"T_4f71e_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
|
||
" <th id=\"T_4f71e_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_4f71e_level0_row0\" class=\"row_heading level0 row0\" >linear</th>\n",
|
||
" <td id=\"T_4f71e_row0_col0\" class=\"data row0 col0\" >0.000000</td>\n",
|
||
" <td id=\"T_4f71e_row0_col1\" class=\"data row0 col1\" >0.000000</td>\n",
|
||
" <td id=\"T_4f71e_row0_col2\" class=\"data row0 col2\" >0.000000</td>\n",
|
||
" <td id=\"T_4f71e_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_4f71e_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
|
||
" <td id=\"T_4f71e_row1_col0\" class=\"data row1 col0\" >0.002409</td>\n",
|
||
" <td id=\"T_4f71e_row1_col1\" class=\"data row1 col1\" >0.002181</td>\n",
|
||
" <td id=\"T_4f71e_row1_col2\" class=\"data row1 col2\" >0.040847</td>\n",
|
||
" <td id=\"T_4f71e_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_4f71e_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
||
" <td id=\"T_4f71e_row2_col0\" class=\"data row2 col0\" >1.856691</td>\n",
|
||
" <td id=\"T_4f71e_row2_col1\" class=\"data row2 col1\" >3.387390</td>\n",
|
||
" <td id=\"T_4f71e_row2_col2\" class=\"data row2 col2\" >1.612410</td>\n",
|
||
" <td id=\"T_4f71e_row2_col3\" class=\"data row2 col3\" >0.967338</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_4f71e_level0_row3\" class=\"row_heading level0 row3\" >decision_tree</th>\n",
|
||
" <td id=\"T_4f71e_row3_col0\" class=\"data row3 col0\" >2.409987</td>\n",
|
||
" <td id=\"T_4f71e_row3_col1\" class=\"data row3 col1\" >4.281378</td>\n",
|
||
" <td id=\"T_4f71e_row3_col2\" class=\"data row3 col2\" >1.847379</td>\n",
|
||
" <td id=\"T_4f71e_row3_col3\" class=\"data row3 col3\" >0.947823</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_4f71e_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
|
||
" <td id=\"T_4f71e_row4_col0\" class=\"data row4 col0\" >5.170083</td>\n",
|
||
" <td id=\"T_4f71e_row4_col1\" class=\"data row4 col1\" >5.824576</td>\n",
|
||
" <td id=\"T_4f71e_row4_col2\" class=\"data row4 col2\" >2.073599</td>\n",
|
||
" <td id=\"T_4f71e_row4_col3\" class=\"data row4 col3\" >0.903431</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_4f71e_level0_row5\" class=\"row_heading level0 row5\" >mlp</th>\n",
|
||
" <td id=\"T_4f71e_row5_col0\" class=\"data row5 col0\" >60.594447</td>\n",
|
||
" <td id=\"T_4f71e_row5_col1\" class=\"data row5 col1\" >59.994001</td>\n",
|
||
" <td id=\"T_4f71e_row5_col2\" class=\"data row5 col2\" >7.553087</td>\n",
|
||
" <td id=\"T_4f71e_row5_col3\" class=\"data row5 col3\" >-9.245313</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_4f71e_level0_row6\" class=\"row_heading level0 row6\" >linear_poly</th>\n",
|
||
" <td id=\"T_4f71e_row6_col0\" class=\"data row6 col0\" >67.720820</td>\n",
|
||
" <td id=\"T_4f71e_row6_col1\" class=\"data row6 col1\" >66.518485</td>\n",
|
||
" <td id=\"T_4f71e_row6_col2\" class=\"data row6 col2\" >8.144355</td>\n",
|
||
" <td id=\"T_4f71e_row6_col3\" class=\"data row6 col3\" >-11.594887</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_4f71e_level0_row7\" class=\"row_heading level0 row7\" >linear_interact</th>\n",
|
||
" <td id=\"T_4f71e_row7_col0\" class=\"data row7 col0\" >67.518306</td>\n",
|
||
" <td id=\"T_4f71e_row7_col1\" class=\"data row7 col1\" >67.518306</td>\n",
|
||
" <td id=\"T_4f71e_row7_col2\" class=\"data row7 col2\" >8.216952</td>\n",
|
||
" <td id=\"T_4f71e_row7_col3\" class=\"data row7 col3\" >-11.976353</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x135c8320bc0>"
|
||
]
|
||
},
|
||
"execution_count": 36,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
|
||
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
|
||
"]\n",
|
||
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
|
||
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
|
||
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок\n",
|
||
"\n",
|
||
"Получение лучшей модели (с наименьшим RMSE на тестовой выборке)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 37,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'linear'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
|
||
"\n",
|
||
"display(best_model)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Подбор гиперпараметров модели RandomForestRegressor методом поиска по сетке (GridSearchCV)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Fitting 3 folds for each of 27 candidates, totalling 81 fits\n",
|
||
"Лучшие параметры: {'max_depth': 20, 'min_samples_split': 5, 'n_estimators': 150}\n",
|
||
"Лучший результат (MSE): 11.258082426680579\n",
|
||
"Ошибка на тестовой выборке (MSE): 9.1015\n",
|
||
"Коэффициент детерминации (R²): 0.9741\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
||
"from sklearn.ensemble import RandomForestRegressor\n",
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"\n",
|
||
"# Удаление пропущенных значений (если есть)\n",
|
||
"df.dropna(inplace=True)\n",
|
||
"\n",
|
||
"X = df[[\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \n",
|
||
" \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]] \n",
|
||
"y = df[\"diabetes_risk_index\"] \n",
|
||
"\n",
|
||
"X_train, X_test, y_train, y_test = train_test_split(\n",
|
||
" X, y, test_size=0.2, random_state=42\n",
|
||
")\n",
|
||
"\n",
|
||
"model = RandomForestRegressor(random_state=42)\n",
|
||
"\n",
|
||
"param_grid = {\n",
|
||
" 'n_estimators': [50, 100, 150], \n",
|
||
" 'max_depth': [10, 20, 30], \n",
|
||
" 'min_samples_split': [5, 10, 15] \n",
|
||
"}\n",
|
||
"\n",
|
||
"grid_search = GridSearchCV(\n",
|
||
" estimator=model,\n",
|
||
" param_grid=param_grid,\n",
|
||
" scoring='neg_mean_squared_error', \n",
|
||
" cv=3, \n",
|
||
" n_jobs=-1, \n",
|
||
" verbose=2 \n",
|
||
")\n",
|
||
"\n",
|
||
"grid_search.fit(X_train, y_train)\n",
|
||
"\n",
|
||
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
|
||
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)\n",
|
||
"\n",
|
||
"best_model = grid_search.best_estimator_\n",
|
||
"y_pred = best_model.predict(X_test)\n",
|
||
"\n",
|
||
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
|
||
"r2 = metrics.r2_score(y_test, y_pred)\n",
|
||
"\n",
|
||
"print(f\"Ошибка на тестовой выборке (MSE): {mse:.4f}\")\n",
|
||
"print(f\"Коэффициент детерминации (R²): {r2:.4f}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Обучение модели с новыми гиперпараметрами и сравнение результатов моделей с новыми и старыми гиперпараметрами"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n",
|
||
"Старые параметры: {'max_depth': 20, 'min_samples_split': 5, 'n_estimators': 100}\n",
|
||
"Лучший результат (MSE) на старых параметрах: 11.682596479300793\n",
|
||
"\n",
|
||
"\n",
|
||
"Новые параметры: {'max_depth': 20, 'min_samples_split': 10, 'n_estimators': 100}\n",
|
||
"Лучший результат (MSE) на новых параметрах: 18.55198575928597\n",
|
||
"Среднеквадратическая ошибка (MSE) на тестовых данных: 10.739599657760765\n",
|
||
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 3.27713284103052\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"from sklearn.ensemble import RandomForestRegressor\n",
|
||
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"\n",
|
||
"\n",
|
||
"old_param_grid = {\n",
|
||
" 'n_estimators': [50, 100],\n",
|
||
" 'max_depth': [ 10, 20],\n",
|
||
" 'min_samples_split': [5, 10]\n",
|
||
"}\n",
|
||
"\n",
|
||
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
|
||
" param_grid=old_param_grid,\n",
|
||
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
|
||
"\n",
|
||
"old_grid_search.fit(X_train, y_train)\n",
|
||
"\n",
|
||
"old_best_params = old_grid_search.best_params_\n",
|
||
"old_best_mse = -old_grid_search.best_score_\n",
|
||
"\n",
|
||
"new_param_grid = {\n",
|
||
" 'n_estimators': [100],\n",
|
||
" 'max_depth': [20],\n",
|
||
" 'min_samples_split': [10]\n",
|
||
"}\n",
|
||
"\n",
|
||
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
|
||
" param_grid=new_param_grid,\n",
|
||
" scoring='neg_mean_squared_error', cv=2)\n",
|
||
"\n",
|
||
"new_grid_search.fit(X_train, y_train)\n",
|
||
"\n",
|
||
"new_best_params = new_grid_search.best_params_\n",
|
||
"new_best_mse = -new_grid_search.best_score_\n",
|
||
"\n",
|
||
"model_best = RandomForestRegressor(**new_best_params)\n",
|
||
"model_best.fit(X_train, y_train)\n",
|
||
"\n",
|
||
"model_oldbest = RandomForestRegressor(**old_best_params)\n",
|
||
"model_oldbest.fit(X_train, y_train)\n",
|
||
"\n",
|
||
"y_pred = model_best.predict(X_test)\n",
|
||
"y_oldpred = model_oldbest.predict(X_test)\n",
|
||
"\n",
|
||
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
|
||
"rmse = np.sqrt(mse)\n",
|
||
"\n",
|
||
"print(\"Старые параметры:\", old_best_params)\n",
|
||
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
|
||
"print(\"\\n\")\n",
|
||
"print(\"Новые параметры:\", new_best_params)\n",
|
||
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
|
||
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
|
||
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 45,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1000x500 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.figure(figsize=(10, 5))\n",
|
||
"plt.scatter(range(len(y_test)), y_test, label=\"Актуальные значения\", color=\"green\", alpha=0.5)\n",
|
||
"plt.scatter(range(len(y_test)), y_pred, label=\"Новые параметры\", color=\"blue\", alpha=0.5)\n",
|
||
"plt.scatter(range(len(y_test)), y_oldpred, label=\"Старые параметры\", color=\"red\", alpha=0.5)\n",
|
||
"plt.xlabel(\"Выборка\")\n",
|
||
"plt.ylabel(\"Значения\")\n",
|
||
"plt.legend()\n",
|
||
"plt.title(\"Новые и старые параметры\")\n",
|
||
"plt.show()"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.5"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|