3832 lines
377 KiB
Plaintext
3832 lines
377 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Подключили датафрейм и выгрузили данные"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Index(['Pregnancies', 'Glucose', 'BloodPressure', 'SkinThickness', 'Insulin',\n",
|
|||
|
" 'BMI', 'DiabetesPedigreeFunction', 'Age', 'Outcome'],\n",
|
|||
|
" dtype='object')\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Pregnancies</th>\n",
|
|||
|
" <th>Glucose</th>\n",
|
|||
|
" <th>BloodPressure</th>\n",
|
|||
|
" <th>SkinThickness</th>\n",
|
|||
|
" <th>Insulin</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>DiabetesPedigreeFunction</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Outcome</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>6</td>\n",
|
|||
|
" <td>148</td>\n",
|
|||
|
" <td>72</td>\n",
|
|||
|
" <td>35</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>33.6</td>\n",
|
|||
|
" <td>0.627</td>\n",
|
|||
|
" <td>50</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>85</td>\n",
|
|||
|
" <td>66</td>\n",
|
|||
|
" <td>29</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>26.6</td>\n",
|
|||
|
" <td>0.351</td>\n",
|
|||
|
" <td>31</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" <td>183</td>\n",
|
|||
|
" <td>64</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>23.3</td>\n",
|
|||
|
" <td>0.672</td>\n",
|
|||
|
" <td>32</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>89</td>\n",
|
|||
|
" <td>66</td>\n",
|
|||
|
" <td>23</td>\n",
|
|||
|
" <td>94</td>\n",
|
|||
|
" <td>28.1</td>\n",
|
|||
|
" <td>0.167</td>\n",
|
|||
|
" <td>21</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>137</td>\n",
|
|||
|
" <td>40</td>\n",
|
|||
|
" <td>35</td>\n",
|
|||
|
" <td>168</td>\n",
|
|||
|
" <td>43.1</td>\n",
|
|||
|
" <td>2.288</td>\n",
|
|||
|
" <td>33</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>763</th>\n",
|
|||
|
" <td>10</td>\n",
|
|||
|
" <td>101</td>\n",
|
|||
|
" <td>76</td>\n",
|
|||
|
" <td>48</td>\n",
|
|||
|
" <td>180</td>\n",
|
|||
|
" <td>32.9</td>\n",
|
|||
|
" <td>0.171</td>\n",
|
|||
|
" <td>63</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>764</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>122</td>\n",
|
|||
|
" <td>70</td>\n",
|
|||
|
" <td>27</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>36.8</td>\n",
|
|||
|
" <td>0.340</td>\n",
|
|||
|
" <td>27</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>765</th>\n",
|
|||
|
" <td>5</td>\n",
|
|||
|
" <td>121</td>\n",
|
|||
|
" <td>72</td>\n",
|
|||
|
" <td>23</td>\n",
|
|||
|
" <td>112</td>\n",
|
|||
|
" <td>26.2</td>\n",
|
|||
|
" <td>0.245</td>\n",
|
|||
|
" <td>30</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>766</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>126</td>\n",
|
|||
|
" <td>60</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>30.1</td>\n",
|
|||
|
" <td>0.349</td>\n",
|
|||
|
" <td>47</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>767</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>93</td>\n",
|
|||
|
" <td>70</td>\n",
|
|||
|
" <td>31</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>30.4</td>\n",
|
|||
|
" <td>0.315</td>\n",
|
|||
|
" <td>23</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>768 rows × 9 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
|||
|
"0 6 148 72 35 0 33.6 \n",
|
|||
|
"1 1 85 66 29 0 26.6 \n",
|
|||
|
"2 8 183 64 0 0 23.3 \n",
|
|||
|
"3 1 89 66 23 94 28.1 \n",
|
|||
|
"4 0 137 40 35 168 43.1 \n",
|
|||
|
".. ... ... ... ... ... ... \n",
|
|||
|
"763 10 101 76 48 180 32.9 \n",
|
|||
|
"764 2 122 70 27 0 36.8 \n",
|
|||
|
"765 5 121 72 23 112 26.2 \n",
|
|||
|
"766 1 126 60 0 0 30.1 \n",
|
|||
|
"767 1 93 70 31 0 30.4 \n",
|
|||
|
"\n",
|
|||
|
" DiabetesPedigreeFunction Age Outcome \n",
|
|||
|
"0 0.627 50 1 \n",
|
|||
|
"1 0.351 31 0 \n",
|
|||
|
"2 0.672 32 1 \n",
|
|||
|
"3 0.167 21 0 \n",
|
|||
|
"4 2.288 33 1 \n",
|
|||
|
".. ... ... ... \n",
|
|||
|
"763 0.171 63 0 \n",
|
|||
|
"764 0.340 27 0 \n",
|
|||
|
"765 0.245 30 0 \n",
|
|||
|
"766 0.349 47 1 \n",
|
|||
|
"767 0.315 23 0 \n",
|
|||
|
"\n",
|
|||
|
"[768 rows x 9 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"df = pd.read_csv(\"data/diabetes.csv\")\n",
|
|||
|
"print(df.columns)\n",
|
|||
|
"df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Бизнес-цели\n",
|
|||
|
"1. Предсказание риска развития диабета. Будем классифицировать пациентов на основе медданных для того, чтобы определить у кого есть риск развития диабета(будем использовать целевой признак \"Outcome\"). Актуальность для раннего выявления диабета.\n",
|
|||
|
"2. Анализ ключевых факторов, влияющих на диабет. Предсказание вероятности развития диабета на основе медданных. Актуальность для планирвоания лечения.\n",
|
|||
|
"## Определение достижимого уровня качества модели для первой задачи\n",
|
|||
|
"Разделение данных на обучающую и тестовые выборки 80/20 для задачи классификации"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Pregnancies</th>\n",
|
|||
|
" <th>Glucose</th>\n",
|
|||
|
" <th>BloodPressure</th>\n",
|
|||
|
" <th>SkinThickness</th>\n",
|
|||
|
" <th>Insulin</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>DiabetesPedigreeFunction</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Outcome</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>353</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>90</td>\n",
|
|||
|
" <td>62</td>\n",
|
|||
|
" <td>12</td>\n",
|
|||
|
" <td>43</td>\n",
|
|||
|
" <td>27.2</td>\n",
|
|||
|
" <td>0.580</td>\n",
|
|||
|
" <td>24</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>711</th>\n",
|
|||
|
" <td>5</td>\n",
|
|||
|
" <td>126</td>\n",
|
|||
|
" <td>78</td>\n",
|
|||
|
" <td>27</td>\n",
|
|||
|
" <td>22</td>\n",
|
|||
|
" <td>29.6</td>\n",
|
|||
|
" <td>0.439</td>\n",
|
|||
|
" <td>40</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>373</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>105</td>\n",
|
|||
|
" <td>58</td>\n",
|
|||
|
" <td>40</td>\n",
|
|||
|
" <td>94</td>\n",
|
|||
|
" <td>34.9</td>\n",
|
|||
|
" <td>0.225</td>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>46</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>146</td>\n",
|
|||
|
" <td>56</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>29.7</td>\n",
|
|||
|
" <td>0.564</td>\n",
|
|||
|
" <td>29</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>682</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>95</td>\n",
|
|||
|
" <td>64</td>\n",
|
|||
|
" <td>39</td>\n",
|
|||
|
" <td>105</td>\n",
|
|||
|
" <td>44.6</td>\n",
|
|||
|
" <td>0.366</td>\n",
|
|||
|
" <td>22</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>451</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>134</td>\n",
|
|||
|
" <td>70</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>28.9</td>\n",
|
|||
|
" <td>0.542</td>\n",
|
|||
|
" <td>23</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>113</th>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>76</td>\n",
|
|||
|
" <td>62</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>34.0</td>\n",
|
|||
|
" <td>0.391</td>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>556</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>97</td>\n",
|
|||
|
" <td>70</td>\n",
|
|||
|
" <td>40</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>38.1</td>\n",
|
|||
|
" <td>0.218</td>\n",
|
|||
|
" <td>30</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>667</th>\n",
|
|||
|
" <td>10</td>\n",
|
|||
|
" <td>111</td>\n",
|
|||
|
" <td>70</td>\n",
|
|||
|
" <td>27</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>27.5</td>\n",
|
|||
|
" <td>0.141</td>\n",
|
|||
|
" <td>40</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>107</th>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>144</td>\n",
|
|||
|
" <td>58</td>\n",
|
|||
|
" <td>28</td>\n",
|
|||
|
" <td>140</td>\n",
|
|||
|
" <td>29.5</td>\n",
|
|||
|
" <td>0.287</td>\n",
|
|||
|
" <td>37</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>614 rows × 9 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
|||
|
"353 1 90 62 12 43 27.2 \n",
|
|||
|
"711 5 126 78 27 22 29.6 \n",
|
|||
|
"373 2 105 58 40 94 34.9 \n",
|
|||
|
"46 1 146 56 0 0 29.7 \n",
|
|||
|
"682 0 95 64 39 105 44.6 \n",
|
|||
|
".. ... ... ... ... ... ... \n",
|
|||
|
"451 2 134 70 0 0 28.9 \n",
|
|||
|
"113 4 76 62 0 0 34.0 \n",
|
|||
|
"556 1 97 70 40 0 38.1 \n",
|
|||
|
"667 10 111 70 27 0 27.5 \n",
|
|||
|
"107 4 144 58 28 140 29.5 \n",
|
|||
|
"\n",
|
|||
|
" DiabetesPedigreeFunction Age Outcome \n",
|
|||
|
"353 0.580 24 0 \n",
|
|||
|
"711 0.439 40 0 \n",
|
|||
|
"373 0.225 25 0 \n",
|
|||
|
"46 0.564 29 0 \n",
|
|||
|
"682 0.366 22 0 \n",
|
|||
|
".. ... ... ... \n",
|
|||
|
"451 0.542 23 1 \n",
|
|||
|
"113 0.391 25 0 \n",
|
|||
|
"556 0.218 30 0 \n",
|
|||
|
"667 0.141 40 1 \n",
|
|||
|
"107 0.287 37 0 \n",
|
|||
|
"\n",
|
|||
|
"[614 rows x 9 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Outcome</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>353</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>711</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>373</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>46</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>682</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>451</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>113</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>556</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>667</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>107</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>614 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Outcome\n",
|
|||
|
"353 0\n",
|
|||
|
"711 0\n",
|
|||
|
"373 0\n",
|
|||
|
"46 0\n",
|
|||
|
"682 0\n",
|
|||
|
".. ...\n",
|
|||
|
"451 1\n",
|
|||
|
"113 0\n",
|
|||
|
"556 0\n",
|
|||
|
"667 1\n",
|
|||
|
"107 0\n",
|
|||
|
"\n",
|
|||
|
"[614 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Pregnancies</th>\n",
|
|||
|
" <th>Glucose</th>\n",
|
|||
|
" <th>BloodPressure</th>\n",
|
|||
|
" <th>SkinThickness</th>\n",
|
|||
|
" <th>Insulin</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>DiabetesPedigreeFunction</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Outcome</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>44</th>\n",
|
|||
|
" <td>7</td>\n",
|
|||
|
" <td>159</td>\n",
|
|||
|
" <td>64</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>27.4</td>\n",
|
|||
|
" <td>0.294</td>\n",
|
|||
|
" <td>40</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>672</th>\n",
|
|||
|
" <td>10</td>\n",
|
|||
|
" <td>68</td>\n",
|
|||
|
" <td>106</td>\n",
|
|||
|
" <td>23</td>\n",
|
|||
|
" <td>49</td>\n",
|
|||
|
" <td>35.5</td>\n",
|
|||
|
" <td>0.285</td>\n",
|
|||
|
" <td>47</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>700</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>122</td>\n",
|
|||
|
" <td>76</td>\n",
|
|||
|
" <td>27</td>\n",
|
|||
|
" <td>200</td>\n",
|
|||
|
" <td>35.9</td>\n",
|
|||
|
" <td>0.483</td>\n",
|
|||
|
" <td>26</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>630</th>\n",
|
|||
|
" <td>7</td>\n",
|
|||
|
" <td>114</td>\n",
|
|||
|
" <td>64</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>27.4</td>\n",
|
|||
|
" <td>0.732</td>\n",
|
|||
|
" <td>34</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>81</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>74</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.102</td>\n",
|
|||
|
" <td>22</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>32</th>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>88</td>\n",
|
|||
|
" <td>58</td>\n",
|
|||
|
" <td>11</td>\n",
|
|||
|
" <td>54</td>\n",
|
|||
|
" <td>24.8</td>\n",
|
|||
|
" <td>0.267</td>\n",
|
|||
|
" <td>22</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>637</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>94</td>\n",
|
|||
|
" <td>76</td>\n",
|
|||
|
" <td>18</td>\n",
|
|||
|
" <td>66</td>\n",
|
|||
|
" <td>31.6</td>\n",
|
|||
|
" <td>0.649</td>\n",
|
|||
|
" <td>23</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>593</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>82</td>\n",
|
|||
|
" <td>52</td>\n",
|
|||
|
" <td>22</td>\n",
|
|||
|
" <td>115</td>\n",
|
|||
|
" <td>28.5</td>\n",
|
|||
|
" <td>1.699</td>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>425</th>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>184</td>\n",
|
|||
|
" <td>78</td>\n",
|
|||
|
" <td>39</td>\n",
|
|||
|
" <td>277</td>\n",
|
|||
|
" <td>37.0</td>\n",
|
|||
|
" <td>0.264</td>\n",
|
|||
|
" <td>31</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>273</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>71</td>\n",
|
|||
|
" <td>78</td>\n",
|
|||
|
" <td>50</td>\n",
|
|||
|
" <td>45</td>\n",
|
|||
|
" <td>33.2</td>\n",
|
|||
|
" <td>0.422</td>\n",
|
|||
|
" <td>21</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>154 rows × 9 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
|||
|
"44 7 159 64 0 0 27.4 \n",
|
|||
|
"672 10 68 106 23 49 35.5 \n",
|
|||
|
"700 2 122 76 27 200 35.9 \n",
|
|||
|
"630 7 114 64 0 0 27.4 \n",
|
|||
|
"81 2 74 0 0 0 0.0 \n",
|
|||
|
".. ... ... ... ... ... ... \n",
|
|||
|
"32 3 88 58 11 54 24.8 \n",
|
|||
|
"637 2 94 76 18 66 31.6 \n",
|
|||
|
"593 2 82 52 22 115 28.5 \n",
|
|||
|
"425 4 184 78 39 277 37.0 \n",
|
|||
|
"273 1 71 78 50 45 33.2 \n",
|
|||
|
"\n",
|
|||
|
" DiabetesPedigreeFunction Age Outcome \n",
|
|||
|
"44 0.294 40 0 \n",
|
|||
|
"672 0.285 47 0 \n",
|
|||
|
"700 0.483 26 0 \n",
|
|||
|
"630 0.732 34 1 \n",
|
|||
|
"81 0.102 22 0 \n",
|
|||
|
".. ... ... ... \n",
|
|||
|
"32 0.267 22 0 \n",
|
|||
|
"637 0.649 23 0 \n",
|
|||
|
"593 1.699 25 0 \n",
|
|||
|
"425 0.264 31 1 \n",
|
|||
|
"273 0.422 21 0 \n",
|
|||
|
"\n",
|
|||
|
"[154 rows x 9 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Outcome</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>44</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>672</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>700</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>630</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>81</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>32</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>637</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>593</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>425</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>273</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>154 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Outcome\n",
|
|||
|
"44 0\n",
|
|||
|
"672 0\n",
|
|||
|
"700 0\n",
|
|||
|
"630 1\n",
|
|||
|
"81 0\n",
|
|||
|
".. ...\n",
|
|||
|
"32 0\n",
|
|||
|
"637 0\n",
|
|||
|
"593 0\n",
|
|||
|
"425 1\n",
|
|||
|
"273 0\n",
|
|||
|
"\n",
|
|||
|
"[154 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from typing import Tuple\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from pandas import DataFrame\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"\n",
|
|||
|
"def split_stratified_into_train_val_test(\n",
|
|||
|
" df_input,\n",
|
|||
|
" stratify_colname=\"y\",\n",
|
|||
|
" frac_train=0.6,\n",
|
|||
|
" frac_val=0.15,\n",
|
|||
|
" frac_test=0.25,\n",
|
|||
|
" random_state=None,\n",
|
|||
|
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
|||
|
" \n",
|
|||
|
" if frac_train + frac_val + frac_test != 1.0:\n",
|
|||
|
" raise ValueError(\n",
|
|||
|
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
|
|||
|
" % (frac_train, frac_val, frac_test)\n",
|
|||
|
" )\n",
|
|||
|
" if stratify_colname not in df_input.columns:\n",
|
|||
|
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
|
|||
|
" X = df_input\n",
|
|||
|
" y = df_input[\n",
|
|||
|
" [stratify_colname]\n",
|
|||
|
" ]\n",
|
|||
|
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
|
|||
|
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" if frac_val <= 0:\n",
|
|||
|
" assert len(df_input) == len(df_train) + len(df_temp)\n",
|
|||
|
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
|
|||
|
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
|
|||
|
" df_val, df_test, y_val, y_test = train_test_split(\n",
|
|||
|
" df_temp,\n",
|
|||
|
" y_temp,\n",
|
|||
|
" stratify=y_temp,\n",
|
|||
|
" test_size=relative_frac_test,\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" )\n",
|
|||
|
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
|
|||
|
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
|
|||
|
"\n",
|
|||
|
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
|
|||
|
" df, stratify_colname=\"Outcome\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_train\", X_train)\n",
|
|||
|
"display(\"y_train\", y_train)\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_test\", X_test)\n",
|
|||
|
"display(\"y_test\", y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Формирование конвейера для классификации данных\n",
|
|||
|
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
|
|||
|
"\n",
|
|||
|
"features_preprocessing -- трансформер для предобработки признаков\n",
|
|||
|
"\n",
|
|||
|
"features_engineering -- трансформер для конструирования признаков\n",
|
|||
|
"\n",
|
|||
|
"drop_columns -- трансформер для удаления колонок\n",
|
|||
|
"\n",
|
|||
|
"features_postprocessing -- трансформер для унитарного кодирования новых признаков\n",
|
|||
|
"\n",
|
|||
|
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков\n",
|
|||
|
"\n",
|
|||
|
"Конвейер выполняется последовательно.\n",
|
|||
|
"\n",
|
|||
|
"Трансформер выполняет параллельно для указанного набора колонок."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn.base import BaseEstimator, TransformerMixin\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"# Построение конвейеров предобработки\n",
|
|||
|
"class DiabetesFeatures(BaseEstimator, TransformerMixin):\n",
|
|||
|
" def __init__(self):\n",
|
|||
|
" pass\n",
|
|||
|
"\n",
|
|||
|
" def fit(self, X, y=None):\n",
|
|||
|
" return self\n",
|
|||
|
"\n",
|
|||
|
" def transform(self, X, y=None):\n",
|
|||
|
" # Добавим признак отношения индекса массы тела и возраста.\n",
|
|||
|
" X = X.copy()\n",
|
|||
|
" X[\"Glucose_Insulin\"] = X[\"Glucose\"] / X[\"Insulin\"]\n",
|
|||
|
" return X\n",
|
|||
|
"\n",
|
|||
|
" def get_feature_names_out(self, features_in):\n",
|
|||
|
" new_features = [\"Glucose_Insulin\"]\n",
|
|||
|
" return np.append(features_in, new_features, axis=0)\n",
|
|||
|
"\n",
|
|||
|
"# Обработка числовых данных. Числовой конвейр: заполнение пропущенных значений медианой и стандартизация\n",
|
|||
|
"preprocessing_num_class = Pipeline(steps=[\n",
|
|||
|
" ('imputer', SimpleImputer(strategy='median')),\n",
|
|||
|
" ('scaler', StandardScaler())\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_cat_class = Pipeline(steps=[\n",
|
|||
|
" ('imputer', SimpleImputer(strategy='most_frequent')),\n",
|
|||
|
" ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop='first'))\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"columns_to_drop = []\n",
|
|||
|
"numeric_columns = [\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\",\n",
|
|||
|
" \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]\n",
|
|||
|
"cat_columns = [\"Outcome\"]\n",
|
|||
|
"\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"preprocessing_num\", preprocessing_num_class, numeric_columns),\n",
|
|||
|
" (\"preprocessing_cat\", preprocessing_cat_class, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"features_postprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" ('preprocessing_cat', preprocessing_cat_class, [\"Outcome\"]),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"custom_features\", DiabetesFeatures()),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" ]\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Работа конвейера:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Pregnancies</th>\n",
|
|||
|
" <th>Glucose</th>\n",
|
|||
|
" <th>BloodPressure</th>\n",
|
|||
|
" <th>SkinThickness</th>\n",
|
|||
|
" <th>Insulin</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>DiabetesPedigreeFunction</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Outcome_1</th>\n",
|
|||
|
" <th>Glucose_Insulin</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>353</th>\n",
|
|||
|
" <td>-0.851355</td>\n",
|
|||
|
" <td>-0.980131</td>\n",
|
|||
|
" <td>-0.404784</td>\n",
|
|||
|
" <td>-0.553973</td>\n",
|
|||
|
" <td>-0.331319</td>\n",
|
|||
|
" <td>-0.607678</td>\n",
|
|||
|
" <td>0.310794</td>\n",
|
|||
|
" <td>-0.792169</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>2.958266</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>711</th>\n",
|
|||
|
" <td>0.356576</td>\n",
|
|||
|
" <td>0.161444</td>\n",
|
|||
|
" <td>0.465368</td>\n",
|
|||
|
" <td>0.392787</td>\n",
|
|||
|
" <td>-0.526398</td>\n",
|
|||
|
" <td>-0.302139</td>\n",
|
|||
|
" <td>-0.116439</td>\n",
|
|||
|
" <td>0.561034</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-0.306696</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>373</th>\n",
|
|||
|
" <td>-0.549372</td>\n",
|
|||
|
" <td>-0.504474</td>\n",
|
|||
|
" <td>-0.622322</td>\n",
|
|||
|
" <td>1.213312</td>\n",
|
|||
|
" <td>0.142444</td>\n",
|
|||
|
" <td>0.372594</td>\n",
|
|||
|
" <td>-0.764862</td>\n",
|
|||
|
" <td>-0.707594</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-3.541575</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>46</th>\n",
|
|||
|
" <td>-0.851355</td>\n",
|
|||
|
" <td>0.795653</td>\n",
|
|||
|
" <td>-0.731091</td>\n",
|
|||
|
" <td>-1.311380</td>\n",
|
|||
|
" <td>-0.730766</td>\n",
|
|||
|
" <td>-0.289408</td>\n",
|
|||
|
" <td>0.262314</td>\n",
|
|||
|
" <td>-0.369293</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-1.088792</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>682</th>\n",
|
|||
|
" <td>-1.153338</td>\n",
|
|||
|
" <td>-0.821579</td>\n",
|
|||
|
" <td>-0.296015</td>\n",
|
|||
|
" <td>1.150195</td>\n",
|
|||
|
" <td>0.244628</td>\n",
|
|||
|
" <td>1.607482</td>\n",
|
|||
|
" <td>-0.337630</td>\n",
|
|||
|
" <td>-0.961320</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-3.358486</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>451</th>\n",
|
|||
|
" <td>-0.549372</td>\n",
|
|||
|
" <td>0.415128</td>\n",
|
|||
|
" <td>0.030292</td>\n",
|
|||
|
" <td>-1.311380</td>\n",
|
|||
|
" <td>-0.730766</td>\n",
|
|||
|
" <td>-0.391255</td>\n",
|
|||
|
" <td>0.195653</td>\n",
|
|||
|
" <td>-0.876744</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>-0.568071</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>113</th>\n",
|
|||
|
" <td>0.054593</td>\n",
|
|||
|
" <td>-1.424076</td>\n",
|
|||
|
" <td>-0.404784</td>\n",
|
|||
|
" <td>-1.311380</td>\n",
|
|||
|
" <td>-0.730766</td>\n",
|
|||
|
" <td>0.258017</td>\n",
|
|||
|
" <td>-0.261879</td>\n",
|
|||
|
" <td>-0.707594</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.948744</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>556</th>\n",
|
|||
|
" <td>-0.851355</td>\n",
|
|||
|
" <td>-0.758158</td>\n",
|
|||
|
" <td>0.030292</td>\n",
|
|||
|
" <td>1.213312</td>\n",
|
|||
|
" <td>-0.730766</td>\n",
|
|||
|
" <td>0.779980</td>\n",
|
|||
|
" <td>-0.786072</td>\n",
|
|||
|
" <td>-0.284718</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.037483</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>667</th>\n",
|
|||
|
" <td>1.866489</td>\n",
|
|||
|
" <td>-0.314212</td>\n",
|
|||
|
" <td>0.030292</td>\n",
|
|||
|
" <td>0.392787</td>\n",
|
|||
|
" <td>-0.730766</td>\n",
|
|||
|
" <td>-0.569486</td>\n",
|
|||
|
" <td>-1.019383</td>\n",
|
|||
|
" <td>0.561034</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.429976</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>107</th>\n",
|
|||
|
" <td>0.054593</td>\n",
|
|||
|
" <td>0.732232</td>\n",
|
|||
|
" <td>-0.622322</td>\n",
|
|||
|
" <td>0.455904</td>\n",
|
|||
|
" <td>0.569759</td>\n",
|
|||
|
" <td>-0.314870</td>\n",
|
|||
|
" <td>-0.577001</td>\n",
|
|||
|
" <td>0.307308</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.285160</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>614 rows × 10 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
|||
|
"353 -0.851355 -0.980131 -0.404784 -0.553973 -0.331319 -0.607678 \n",
|
|||
|
"711 0.356576 0.161444 0.465368 0.392787 -0.526398 -0.302139 \n",
|
|||
|
"373 -0.549372 -0.504474 -0.622322 1.213312 0.142444 0.372594 \n",
|
|||
|
"46 -0.851355 0.795653 -0.731091 -1.311380 -0.730766 -0.289408 \n",
|
|||
|
"682 -1.153338 -0.821579 -0.296015 1.150195 0.244628 1.607482 \n",
|
|||
|
".. ... ... ... ... ... ... \n",
|
|||
|
"451 -0.549372 0.415128 0.030292 -1.311380 -0.730766 -0.391255 \n",
|
|||
|
"113 0.054593 -1.424076 -0.404784 -1.311380 -0.730766 0.258017 \n",
|
|||
|
"556 -0.851355 -0.758158 0.030292 1.213312 -0.730766 0.779980 \n",
|
|||
|
"667 1.866489 -0.314212 0.030292 0.392787 -0.730766 -0.569486 \n",
|
|||
|
"107 0.054593 0.732232 -0.622322 0.455904 0.569759 -0.314870 \n",
|
|||
|
"\n",
|
|||
|
" DiabetesPedigreeFunction Age Outcome_1 Glucose_Insulin \n",
|
|||
|
"353 0.310794 -0.792169 0.0 2.958266 \n",
|
|||
|
"711 -0.116439 0.561034 0.0 -0.306696 \n",
|
|||
|
"373 -0.764862 -0.707594 0.0 -3.541575 \n",
|
|||
|
"46 0.262314 -0.369293 0.0 -1.088792 \n",
|
|||
|
"682 -0.337630 -0.961320 0.0 -3.358486 \n",
|
|||
|
".. ... ... ... ... \n",
|
|||
|
"451 0.195653 -0.876744 1.0 -0.568071 \n",
|
|||
|
"113 -0.261879 -0.707594 0.0 1.948744 \n",
|
|||
|
"556 -0.786072 -0.284718 0.0 1.037483 \n",
|
|||
|
"667 -1.019383 0.561034 1.0 0.429976 \n",
|
|||
|
"107 -0.577001 0.307308 0.0 1.285160 \n",
|
|||
|
"\n",
|
|||
|
"[614 rows x 10 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"preprocessed_df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Формирование набора моделей для классификации\n",
|
|||
|
"logistic -- логистическая регрессия\n",
|
|||
|
"\n",
|
|||
|
"ridge -- гребневая регрессия\n",
|
|||
|
"\n",
|
|||
|
"decision_tree -- дерево решений\n",
|
|||
|
"\n",
|
|||
|
"knn -- k-ближайших соседей\n",
|
|||
|
"\n",
|
|||
|
"naive_bayes -- наивный Байесовский классификатор\n",
|
|||
|
"\n",
|
|||
|
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"random_forest -- метод случайного леса (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"mlp -- многослойный персептрон (нейронная сеть)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
|
|||
|
"\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"\n",
|
|||
|
"class_models = {\n",
|
|||
|
" \"logistic\": {\"model\": linear_model.LogisticRegression(random_state=random_state)},\n",
|
|||
|
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\", random_state=random_state)},\n",
|
|||
|
" \"decision_tree\": {\n",
|
|||
|
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
|
|||
|
" },\n",
|
|||
|
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
|
|||
|
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
|
|||
|
" \"gradient_boosting\": {\n",
|
|||
|
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210, random_state=random_state)\n",
|
|||
|
" },\n",
|
|||
|
" \"random_forest\": {\n",
|
|||
|
" \"model\": ensemble.RandomForestClassifier(\n",
|
|||
|
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"mlp\": {\n",
|
|||
|
" \"model\": neural_network.MLPClassifier(\n",
|
|||
|
" hidden_layer_sizes=(7,),\n",
|
|||
|
" max_iter=500,\n",
|
|||
|
" early_stopping=True,\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
"}"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение моделей на обучающем наборе данных и оценка на тестовом"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"Model: knn\n",
|
|||
|
"Model: naive_bayes\n",
|
|||
|
"Model: gradient_boosting\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
"\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
|||
|
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
|||
|
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
|
|||
|
" y_test, y_test_probs\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Сводная таблица оценок качества для использованных моделей классификации"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA3EAAAQ9CAYAAAD3ScTVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVwU9f8H8NcAwiI3iCwoIt544pWRZ4ZhpWna1yz6hebR5Z1npRzelkl4lqVoaWalppWWt6lkeaaJNyomoImAoFy7n98f5NYKi7s6wM7wen4f8/i6n/nM7Gc24eV79jMzkhBCgIiIiIiIiBTBpqIHQEREREREROZjEUdERERERKQgLOKIiIiIiIgUhEUcERERERGRgrCIIyIiIiIiUhAWcURERERERArCIo6IiIiIiEhBWMQREREREREpCIs4IiIiIiIiBWERR5VefHw8JEnCxYsXy2T/Fy9ehCRJiI+Pl2V/u3btgiRJ2LVrlyz7IyIiUouoqChIkmRWX0mSEBUVVbYDIiojLOKIrNSiRYtkK/yIiIiISD3sKnoARGoXEBCAO3fuoEqVKhZtt2jRIlSrVg0DBgwwau/UqRPu3LkDe3t7GUdJRESkfO+99x4mTpxY0cMgKnMs4ojKmCRJ0Gg0su3PxsZG1v0RERGpQU5ODpycnGBnx3/ekvpxOiVRCRYtWoQmTZrAwcEBfn5+eOutt5CRkVGs38KFC1GnTh04OjrikUcewS+//IIuXbqgS5cuhj4lXROXmpqKgQMHombNmnBwcICvry969epluC6vdu3a+PPPP7F7925IkgRJkgz7NHVN3IEDB/D000/Dw8MDTk5OaN68OT766CN5PxgiIiIrcPfat5MnT+Kll16Ch4cHOnToUOI1cXl5eRg9ejS8vb3h4uKCZ599FleuXClxv7t27UKbNm2g0WhQt25dfPzxxyavs/viiy/QunVrODo6wtPTE/3790dycnKZHC/RvXiqgugeUVFRiI6ORmhoKN544w2cPn0aixcvxu+//459+/YZpkUuXrwYw4YNQ8eOHTF69GhcvHgRvXv3hoeHB2rWrFnqe/Tt2xd//vknhg8fjtq1a+PatWvYunUrLl++jNq1ayM2NhbDhw+Hs7Mz3n33XQCAj4+Pyf1t3boVPXr0gK+vL0aOHAmtVovExER8//33GDlypHwfDhERkRX53//+h/r162PGjBkQQuDatWvF+gwePBhffPEFXnrpJTz22GPYsWMHnnnmmWL9jhw5gu7du8PX1xfR0dHQ6XSIiYmBt7d3sb7Tp0/H5MmT0a9fPwwePBjXr1/H/Pnz0alTJxw5cgTu7u5lcbhE/xJEldzy5csFAJGUlCSuXbsm7O3txZNPPil0Op2hz4IFCwQAsWzZMiGEEHl5ecLLy0u0bdtWFBQUGPrFx8cLAKJz586GtqSkJAFALF++XAghxM2bNwUA8f7775c6riZNmhjt566dO3cKAGLnzp1CCCEKCwtFYGCgCAgIEDdv3jTqq9frzf8giIiIFCIyMlIAEC+++GKJ7XcdPXpUABBvvvmmUb+XXnpJABCRkZGGtp49e4qqVauKv/76y9B29uxZYWdnZ7TPixcvCltbWzF9+nSjfR4/flzY2dkVaycqC5xOSfQf27ZtQ35+PkaNGgUbm39/PIYMGQJXV1f88MMPAICDBw/ixo0bGDJkiNHc+/DwcHh4eJT6Ho6OjrC3t8euXbtw8+bNhx7zkSNHkJSUhFGjRhU782fubZaJiIiU6PXXXy91/Y8//ggAGDFihFH7qFGjjF7rdDps27YNvXv3hp+fn6G9Xr16eOqpp4z6rlu3Dnq9Hv369cPff/9tWLRaLerXr4+dO3c+xBERmYfTKYn+49KlSwCAhg0bGrXb29ujTp06hvV3/79evXpG/ezs7FC7du1S38PBwQGzZ8/G22+/DR8fHzz66KPo0aMHXnnlFWi1WovHfP78eQBA06ZNLd6WiIhIyQIDA0tdf+nSJdjY2KBu3bpG7ffm/LVr13Dnzp1iuQ4Uz/qzZ89CCIH69euX+J6W3o2a6EGwiCOqAKNGjULPnj2xYcMG/PTTT5g8eTJmzpyJHTt2oGXLlhU9PCIiIkVwdHQs9/fU6/WQJAmbN2+Gra1tsfXOzs7lPiaqfDidkug/AgICAACnT582as/Pz0dSUpJh/d3/P3funFG/wsJCwx0m76du3bp4++238fPPP+PEiRPIz8/H3LlzDevNnQp59+ziiRMnzOpPRERUWQQEBECv1xtmrdx1b85Xr14dGo2mWK4DxbO+bt26EEIgMDAQoaGhxZZHH31U/gMhugeLOKL/CA0Nhb29PeLi4iCEMLR/9tlnyMzMNNzNqk2bNvDy8sLSpUtRWFho6Ldq1ar7Xud2+/Zt5ObmGrXVrVsXLi4uyMvLM7Q5OTmV+FiDe7Vq1QqBgYGIjY0t1v+/x0BERFTZ3L2eLS4uzqg9NjbW6LWtrS1CQ0OxYcMGXL161dB+7tw5bN682ahvnz59YGtri+jo6GI5K4TAjRs3ZDwCopJxOiXRf3h7e2PSpEmIjo5G9+7d8eyzz+L06dNYtGgR2rZti5dffhlA0TVyUVFRGD58OLp27Yp+/frh4sWLiI+PR926dUv9Fu3MmTN44okn0K9fPzRu3Bh2dnZYv3490tLS0L9/f0O/1q1bY/HixZg2bRrq1auH6tWro2vXrsX2Z2Njg8WLF6Nnz54IDg7GwIED4evri1OnTuHPP//ETz/9JP8HRUREpADBwcF48cUXsWjRImRmZuKxxx7D9u3bS/zGLSoqCj///DPat2+PN954AzqdDgsWLEDTpk1x9OhRQ7+6deti2rRpmDRpkuHxQi4uLkhKSsL69esxdOhQjB07thyPkiojFnFE94iKioK3tzcWLFiA0aNHw9PTE0OHDsWMGTOMLlYeNmwYhBCYO3cuxo4dixYtWmDjxo0YMWIENBqNyf37+/vjxRdfxPbt2/H555/Dzs4OjRo1wtq1a9G3b19DvylTpuDSpUuYM2cObt26hc6dO5dYxAFAWFgYdu7ciejoaMydOxd6vR5169bFkCFD5PtgiIiIFGjZsmXw9vbGqlWrsGHDBnTt2hU//PAD/P39jfq1bt0amzdvxtixYzF58mT4+/sjJiYGiYmJOHXqlFHfiRMnokGDBpg3bx6io6MBFOX7k08+iWeffbbcjo0qL0lwvhWRbPR6Pby9vdGnTx8sXbq0oodDRERED6l37974888/cfbs2YoeCpEBr4kjekC5ubnF5sKvXLkS6enp6NKlS8UMioiIiB7YnTt3jF6fPXsWP/74I3OdrA6/iSN6QLt27cLo0aPxv//9D15eXjh8+DA+++wzBAUF4dChQ7C3t6/oIRIREZEFfH19MWDAAMOzYRcvXoy8vDwcOXLE5HPhiCoCr4kjekC1a9eGv78/4uLikJ6eDk9PT7zyyiuYNWsWCzgiIiIF6t69O7788kukpqbCwcEBISEhmDFjBgs4sjqcTkn0gGrXro2NGzciNTUV+fn5SE1NxbJly1C9evWKHhqpxJ49e9CzZ0/4+flBkiRs2LDBaL0QAlOmTIGvry8cHR0RGhpa7JqN9PR0hIeHw9XVFe7u7hg0aBCys7PL8SiIiJRj+fLluHjxInJzc5GZmYktW7agVatWFT0ssiLWks0s4oiIrFROTg5atGiBhQsXlrh+zpw5iIuLw5IlS3DgwAE4OTkhLCzM6DmE4eHh+PPPP7F161Z8//332LNnD4YOHVpeh0BERKQq1pLNvCaOiEgBJEnC+vXr0bt3bwBFZ/r8/Pzw9ttvG55HlJmZCR8fH8THx6N///5ITExE48aN8fvvv6NNmzYAgC1btuDpp5/GlStX4OfnV1GHQ0REpHgVmc28Jq4S0ev1uHr1KlxcXEp9GDWRWgkhcOvWLfj5+cHGRt6JCLm5ucjPz7/v+9/7s+fg4AAHBweL3y8pKQmpqakIDQ01tLm5uaFdu3ZISEhA//79kZCQAHd3d0NIAEBoaChsbGxw4MABPPfccxa/LxHJi9lMlR2z+cGymUVcJXL16tV
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x1000 with 16 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"n_rows = int(len(class_models) / 2)\n",
|
|||
|
"n_cols = 2\n",
|
|||
|
"\n",
|
|||
|
"fig, ax = plt.subplots(n_rows, n_cols, figsize=(12, 10), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"for index, key in enumerate(class_models.keys()):\n",
|
|||
|
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"No Diabetes\", \"Diabetes\"]\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(key)\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Значение 100 - это количество вернных диагнозов (True Positives), там модель верно определила людей у которых нет диабета \"No Diabetes\".\n",
|
|||
|
"\n",
|
|||
|
"Значение 54 у некоторых моделей - это количество неверных диагнозов (False Negatives), там модель неверно определила людей с диабетом, те, у кого нет диабета \"No Diabetes\" были отнесены к классу \"Diabetes\".\n",
|
|||
|
"\n",
|
|||
|
"Исходя из истинных и ложных значений (True Positives и False Negatives), можно сделать вывод, что модель имеет высокую точность при предсказании людей без диабета \"No Diabetes\". Уровень ложных результатов в некотоорых моделях со значением 54 говорит о том, что есть такие данные, которые модель пропускает.\n",
|
|||
|
"\n",
|
|||
|
"Точность, полнота, верность (аккуратность), F-мера"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_fed29_row0_col0, #T_fed29_row0_col1, #T_fed29_row0_col2, #T_fed29_row0_col3, #T_fed29_row1_col0, #T_fed29_row1_col1, #T_fed29_row1_col2, #T_fed29_row1_col3, #T_fed29_row2_col0, #T_fed29_row2_col1, #T_fed29_row2_col2, #T_fed29_row2_col3, #T_fed29_row3_col0, #T_fed29_row3_col1, #T_fed29_row3_col2, #T_fed29_row3_col3, #T_fed29_row4_col0, #T_fed29_row4_col1, #T_fed29_row4_col2, #T_fed29_row4_col3, #T_fed29_row5_col0, #T_fed29_row5_col1, #T_fed29_row5_col2, #T_fed29_row5_col3 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row0_col4, #T_fed29_row0_col5, #T_fed29_row0_col6, #T_fed29_row0_col7, #T_fed29_row1_col4, #T_fed29_row1_col5, #T_fed29_row1_col6, #T_fed29_row1_col7, #T_fed29_row2_col4, #T_fed29_row2_col5, #T_fed29_row2_col6, #T_fed29_row2_col7, #T_fed29_row3_col4, #T_fed29_row3_col5, #T_fed29_row3_col6, #T_fed29_row3_col7, #T_fed29_row4_col4, #T_fed29_row4_col5, #T_fed29_row4_col6, #T_fed29_row4_col7, #T_fed29_row5_col4, #T_fed29_row5_col5, #T_fed29_row5_col6, #T_fed29_row5_col7 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row6_col0 {\n",
|
|||
|
" background-color: #86d549;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row6_col1 {\n",
|
|||
|
" background-color: #4ac16d;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row6_col2 {\n",
|
|||
|
" background-color: #75d054;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row6_col3 {\n",
|
|||
|
" background-color: #69cd5b;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row6_col4 {\n",
|
|||
|
" background-color: #c43e7f;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row6_col5 {\n",
|
|||
|
" background-color: #ae2892;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row6_col6 {\n",
|
|||
|
" background-color: #cc4977;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row6_col7 {\n",
|
|||
|
" background-color: #c13b82;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row7_col0, #T_fed29_row7_col1, #T_fed29_row7_col2, #T_fed29_row7_col3 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_fed29_row7_col4, #T_fed29_row7_col5, #T_fed29_row7_col6, #T_fed29_row7_col7 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_fed29\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_fed29_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_fed29_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_fed29_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_fed29_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_fed29_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_fed29_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_fed29_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_fed29_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_fed29_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
|
|||
|
" <td id=\"T_fed29_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_fed29_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
|
|||
|
" <td id=\"T_fed29_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_fed29_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_fed29_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_fed29_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_fed29_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_fed29_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_fed29_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_fed29_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_fed29_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row5_col5\" class=\"data row5 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row5_col6\" class=\"data row5 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_fed29_row5_col7\" class=\"data row5 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_fed29_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
|
|||
|
" <td id=\"T_fed29_row6_col0\" class=\"data row6 col0\" >0.947090</td>\n",
|
|||
|
" <td id=\"T_fed29_row6_col1\" class=\"data row6 col1\" >0.796296</td>\n",
|
|||
|
" <td id=\"T_fed29_row6_col2\" class=\"data row6 col2\" >0.836449</td>\n",
|
|||
|
" <td id=\"T_fed29_row6_col3\" class=\"data row6 col3\" >0.796296</td>\n",
|
|||
|
" <td id=\"T_fed29_row6_col4\" class=\"data row6 col4\" >0.926710</td>\n",
|
|||
|
" <td id=\"T_fed29_row6_col5\" class=\"data row6 col5\" >0.857143</td>\n",
|
|||
|
" <td id=\"T_fed29_row6_col6\" class=\"data row6 col6\" >0.888337</td>\n",
|
|||
|
" <td id=\"T_fed29_row6_col7\" class=\"data row6 col7\" >0.796296</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_fed29_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
|
|||
|
" <td id=\"T_fed29_row7_col0\" class=\"data row7 col0\" >0.552632</td>\n",
|
|||
|
" <td id=\"T_fed29_row7_col1\" class=\"data row7 col1\" >0.428571</td>\n",
|
|||
|
" <td id=\"T_fed29_row7_col2\" class=\"data row7 col2\" >0.098131</td>\n",
|
|||
|
" <td id=\"T_fed29_row7_col3\" class=\"data row7 col3\" >0.111111</td>\n",
|
|||
|
" <td id=\"T_fed29_row7_col4\" class=\"data row7 col4\" >0.657980</td>\n",
|
|||
|
" <td id=\"T_fed29_row7_col5\" class=\"data row7 col5\" >0.636364</td>\n",
|
|||
|
" <td id=\"T_fed29_row7_col6\" class=\"data row7 col6\" >0.166667</td>\n",
|
|||
|
" <td id=\"T_fed29_row7_col7\" class=\"data row7 col7\" >0.176471</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x135c39b70b0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(\n",
|
|||
|
" by=\"Accuracy_test\", ascending=False\n",
|
|||
|
").style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Почти все модели в данной выборке, а именно логистическая регрессия, ридж-регрессия, дерево решений, наивный байесовский классификатор, случайный лес, градиентный бустинг, KNN демонстрируют неплохие значения по всем метрикам на обучающих и тестовых наборах данных. Модель MLP не так эффективна по сравнению с другими, но в некоторых метриках показывает неплохие результаты."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_2bffc_row0_col0, #T_2bffc_row0_col1, #T_2bffc_row1_col0, #T_2bffc_row1_col1, #T_2bffc_row2_col0, #T_2bffc_row2_col1, #T_2bffc_row3_col0, #T_2bffc_row3_col1, #T_2bffc_row4_col0, #T_2bffc_row4_col1, #T_2bffc_row5_col0, #T_2bffc_row5_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_2bffc_row0_col2, #T_2bffc_row0_col3, #T_2bffc_row0_col4, #T_2bffc_row1_col2, #T_2bffc_row1_col3, #T_2bffc_row1_col4, #T_2bffc_row2_col2, #T_2bffc_row2_col3, #T_2bffc_row2_col4, #T_2bffc_row3_col2, #T_2bffc_row3_col3, #T_2bffc_row3_col4, #T_2bffc_row4_col2, #T_2bffc_row4_col3, #T_2bffc_row4_col4, #T_2bffc_row5_col2, #T_2bffc_row5_col3, #T_2bffc_row5_col4 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_2bffc_row6_col0 {\n",
|
|||
|
" background-color: #42be71;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_2bffc_row6_col1 {\n",
|
|||
|
" background-color: #65cb5e;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_2bffc_row6_col2 {\n",
|
|||
|
" background-color: #c5407e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_2bffc_row6_col3 {\n",
|
|||
|
" background-color: #b7318a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_2bffc_row6_col4 {\n",
|
|||
|
" background-color: #b6308b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_2bffc_row7_col0, #T_2bffc_row7_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_2bffc_row7_col2, #T_2bffc_row7_col3, #T_2bffc_row7_col4 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_2bffc\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_2bffc_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_2bffc_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_2bffc_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_2bffc_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_2bffc_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_2bffc_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
|
|||
|
" <td id=\"T_2bffc_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_2bffc_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
|
|||
|
" <td id=\"T_2bffc_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_2bffc_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_2bffc_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_2bffc_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_2bffc_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_2bffc_level0_row4\" class=\"row_heading level0 row4\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_2bffc_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_2bffc_level0_row5\" class=\"row_heading level0 row5\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_2bffc_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_2bffc_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_2bffc_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
|
|||
|
" <td id=\"T_2bffc_row6_col0\" class=\"data row6 col0\" >0.857143</td>\n",
|
|||
|
" <td id=\"T_2bffc_row6_col1\" class=\"data row6 col1\" >0.796296</td>\n",
|
|||
|
" <td id=\"T_2bffc_row6_col2\" class=\"data row6 col2\" >0.923333</td>\n",
|
|||
|
" <td id=\"T_2bffc_row6_col3\" class=\"data row6 col3\" >0.686296</td>\n",
|
|||
|
" <td id=\"T_2bffc_row6_col4\" class=\"data row6 col4\" >0.686296</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_2bffc_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
|
|||
|
" <td id=\"T_2bffc_row7_col0\" class=\"data row7 col0\" >0.636364</td>\n",
|
|||
|
" <td id=\"T_2bffc_row7_col1\" class=\"data row7 col1\" >0.176471</td>\n",
|
|||
|
" <td id=\"T_2bffc_row7_col2\" class=\"data row7 col2\" >0.630556</td>\n",
|
|||
|
" <td id=\"T_2bffc_row7_col3\" class=\"data row7 col3\" >0.037500</td>\n",
|
|||
|
" <td id=\"T_2bffc_row7_col4\" class=\"data row7 col4\" >0.051640</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x135c3bf80b0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Такой же вывод можно сделать и для следующих метрик: Accuracy, F1, ROC AUC, Cohen's Kappa и MCC. Все модели, кроме KNN и MLP, указывают на хорошо-развитую способность к выделению классов"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'logistic'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
|
|||
|
"\n",
|
|||
|
"display(best_model)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Вывод данных с ошибкой предсказания для оценки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 17,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'Error items count: 0'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Pregnancies</th>\n",
|
|||
|
" <th>Predicted</th>\n",
|
|||
|
" <th>Glucose</th>\n",
|
|||
|
" <th>BloodPressure</th>\n",
|
|||
|
" <th>SkinThickness</th>\n",
|
|||
|
" <th>Insulin</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>DiabetesPedigreeFunction</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Outcome</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"Empty DataFrame\n",
|
|||
|
"Columns: [Pregnancies, Predicted, Glucose, BloodPressure, SkinThickness, Insulin, BMI, DiabetesPedigreeFunction, Age, Outcome]\n",
|
|||
|
"Index: []"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 17,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"preprocessing_result = pipeline_end.transform(X_test)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"y_pred = class_models[best_model][\"preds\"]\n",
|
|||
|
"\n",
|
|||
|
"error_index = y_test[y_test[\"Outcome\"] != y_pred].index.tolist()\n",
|
|||
|
"display(f\"Error items count: {len(error_index)}\")\n",
|
|||
|
"\n",
|
|||
|
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
|
|||
|
"error_df = X_test.loc[error_index].copy()\n",
|
|||
|
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
|
|||
|
"error_df.sort_index()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Пример использования обученной модели (конвейера) для предсказания"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 22,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Pregnancies</th>\n",
|
|||
|
" <th>Glucose</th>\n",
|
|||
|
" <th>BloodPressure</th>\n",
|
|||
|
" <th>SkinThickness</th>\n",
|
|||
|
" <th>Insulin</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>DiabetesPedigreeFunction</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Outcome</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>190</th>\n",
|
|||
|
" <td>3.0</td>\n",
|
|||
|
" <td>111.0</td>\n",
|
|||
|
" <td>62.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>22.6</td>\n",
|
|||
|
" <td>0.142</td>\n",
|
|||
|
" <td>21.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
|||
|
"190 3.0 111.0 62.0 0.0 0.0 22.6 \n",
|
|||
|
"\n",
|
|||
|
" DiabetesPedigreeFunction Age Outcome \n",
|
|||
|
"190 0.142 21.0 0.0 "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Pregnancies</th>\n",
|
|||
|
" <th>Glucose</th>\n",
|
|||
|
" <th>BloodPressure</th>\n",
|
|||
|
" <th>SkinThickness</th>\n",
|
|||
|
" <th>Insulin</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>DiabetesPedigreeFunction</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Outcome_1</th>\n",
|
|||
|
" <th>Glucose_Insulin</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>190</th>\n",
|
|||
|
" <td>-0.24739</td>\n",
|
|||
|
" <td>-0.314212</td>\n",
|
|||
|
" <td>-0.404784</td>\n",
|
|||
|
" <td>-1.31138</td>\n",
|
|||
|
" <td>-0.730766</td>\n",
|
|||
|
" <td>-1.193296</td>\n",
|
|||
|
" <td>-1.016353</td>\n",
|
|||
|
" <td>-1.045895</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.429976</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
|||
|
"190 -0.24739 -0.314212 -0.404784 -1.31138 -0.730766 -1.193296 \n",
|
|||
|
"\n",
|
|||
|
" DiabetesPedigreeFunction Age Outcome_1 Glucose_Insulin \n",
|
|||
|
"190 -1.016353 -1.045895 0.0 0.429976 "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'predicted: 0 (proba: [0.99177252 0.00822748])'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'real: 0'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"model = class_models[best_model][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"example_id = 190\n",
|
|||
|
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
|
|||
|
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
|
|||
|
"display(test)\n",
|
|||
|
"display(test_preprocessed)\n",
|
|||
|
"result_proba = model.predict_proba(test)[0]\n",
|
|||
|
"result = model.predict(test)[0]\n",
|
|||
|
"real = int(y_test.loc[example_id].values[0])\n",
|
|||
|
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
|
|||
|
"display(f\"real: {real}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Подбор гиперпараметров методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 25,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tabee\\AIM_PIbd-31_Tabeev_A.P\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
|||
|
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'model__criterion': 'gini',\n",
|
|||
|
" 'model__max_depth': 5,\n",
|
|||
|
" 'model__max_features': 'sqrt',\n",
|
|||
|
" 'model__n_estimators': 10}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 25,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import GridSearchCV\n",
|
|||
|
"\n",
|
|||
|
"optimized_model_type = \"random_forest\"\n",
|
|||
|
"\n",
|
|||
|
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" \"model__n_estimators\": [10, 50, 100],\n",
|
|||
|
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
|
|||
|
" \"model__max_depth\": [5, 7, 10],\n",
|
|||
|
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"gs_optomizer = GridSearchCV(\n",
|
|||
|
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
|
|||
|
")\n",
|
|||
|
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"gs_optomizer.best_params_"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение модели с новыми гиперпараметрами"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 26,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
|
|||
|
"\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"# Определение трансформера\n",
|
|||
|
"pipeline_end = ColumnTransformer([\n",
|
|||
|
" ('numeric', StandardScaler(), numeric_features),\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"optimized_model = RandomForestClassifier(\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" criterion=\"gini\",\n",
|
|||
|
" max_depth=5,\n",
|
|||
|
" max_features=\"sqrt\",\n",
|
|||
|
" n_estimators=50,\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"result = {}\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели\n",
|
|||
|
"result[\"pipeline\"] = Pipeline([\n",
|
|||
|
" (\"pipeline\", pipeline_end),\n",
|
|||
|
" (\"model\", optimized_model)\n",
|
|||
|
"]).fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
"# Прогнозирование и расчет метрик\n",
|
|||
|
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
|
|||
|
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
|
|||
|
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
"# Метрики для оценки модели\n",
|
|||
|
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
|
|||
|
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование данных для оценки старой и новой версии модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 27,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=class_models[optimized_model_type]\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=result\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
|
|||
|
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Оценка параметров старой и новой модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 28,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_02d85_row0_col0, #T_02d85_row0_col1, #T_02d85_row0_col2, #T_02d85_row0_col3, #T_02d85_row1_col0, #T_02d85_row1_col1, #T_02d85_row1_col2, #T_02d85_row1_col3 {\n",
|
|||
|
" background-color: #440154;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_02d85_row0_col4, #T_02d85_row0_col5, #T_02d85_row0_col6, #T_02d85_row0_col7, #T_02d85_row1_col4, #T_02d85_row1_col5, #T_02d85_row1_col6, #T_02d85_row1_col7 {\n",
|
|||
|
" background-color: #0d0887;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_02d85\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_02d85_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_02d85_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_02d85_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_02d85_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_02d85_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_02d85_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_02d85_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_02d85_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" <th class=\"blank col5\" > </th>\n",
|
|||
|
" <th class=\"blank col6\" > </th>\n",
|
|||
|
" <th class=\"blank col7\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_02d85_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_02d85_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_02d85_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_02d85_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_02d85_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x135c73282c0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 28,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"].style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 29,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_34fcd_row0_col0, #T_34fcd_row0_col1, #T_34fcd_row1_col0, #T_34fcd_row1_col1 {\n",
|
|||
|
" background-color: #440154;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_34fcd_row0_col2, #T_34fcd_row0_col3, #T_34fcd_row0_col4, #T_34fcd_row1_col3, #T_34fcd_row1_col4 {\n",
|
|||
|
" background-color: #0d0887;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_34fcd_row1_col2 {\n",
|
|||
|
" background-color: #f0f921;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_34fcd\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_34fcd_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_34fcd_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_34fcd_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_34fcd_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_34fcd_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_34fcd_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_34fcd_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_34fcd_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_34fcd_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_34fcd_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_34fcd_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_34fcd_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_34fcd_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_34fcd_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_34fcd_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_34fcd_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_34fcd_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x135c7328590>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 29,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"].style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 30,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA5IAAAGxCAYAAAAQ1omjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOw0lEQVR4nO3dfVxUZf7/8fcBBLzhRrzhRhE177A0tVxzK+0Gwy0rs77mLruLpba7ZaZmVlugYom2aS5W6mZJ7s/WbFNTK3fLVss0S1OrlbS8CU3RNhWEFgRmfn+wTjsLKsMczgxzXs/H4zyCczfXTMibz7mucx3D6XQ6BQAAAABALQX5ugEAAAAAgIaFQhIAAAAA4BEKSQAAAACARygkAQAAAAAeoZAEAAAAAHiEQhIAAAAA4BEKSQAAAACARygkAQAAAAAeoZAEAAAAAHiEQhIAAAAA4BEKSQBAQHv//fd18803KyEhQYZhaNWqVW7bnU6nMjMzFR8fr8aNGyslJUVfffWV2z4nTpxQWlqaIiMjFR0drVGjRqm4uNjCdwEAQBV/yTUKSQBAQCspKdGll16q5557rsbtTz31lHJycrRgwQJt3bpVTZs2VWpqqkpLS137pKWl6Z///KfeeecdrV27Vu+//77uueceq94CAAAu/pJrhtPpdHr1TgAAaCAMw9DKlSs1dOhQSVVXbRMSEvTggw9q0qRJkqTCwkLFxsYqNzdXI0aMUF5enrp3765PPvlEl19+uSRp3bp1uvHGG3X48GElJCT46u0AAGzOl7kWUi/vCACA/1JaWqozZ86Ydj6n0ynDMNzWhYWFKSwszKPzHDhwQAUFBUpJSXGti4qKUr9+/bRlyxaNGDFCW7ZsUXR0tCtsJSklJUVBQUHaunWrbrvtNu/eDACgwSHXKCQBAPWstLRUHZKaqeB4pWnnbNasWbV7OaZMmaKpU6d6dJ6CggJJUmxsrNv62NhY17aCggK1bt3abXtISIhiYmJc+wAA7INc+88xHrUMAAAPnTlzRgXHK3Vge5IiI7y/Nb/otEMdLvtGhw4dUmRkpGu9p1dtAQCoC3KtCoUkAMASkRFBpgSu63yRkW6BWxdxcXGSpGPHjik+Pt61/tixY+rVq5drn+PHj7sdV1FRoRMnTriOBwDYj91zjVlbAQCWqHQ6TFvM0qFDB8XFxWn9+vWudUVFRdq6dav69+8vSerfv79OnTql7du3u/Z577335HA41K9fP9PaAgBoWOyea/RIAgAs4ZBTDnk/Ubin5yguLtbXX3/t+v7AgQPauXOnYmJi1K5dO40fP15PPPGEOnfurA4dOigjI0MJCQmuGfCSk5M1ePBgjRkzRgsWLFB5ebnGjh2rESNGMGMrANiY3XONQhIAENC2bduma6+91vX9xIkTJUnp6enKzc3V5MmTVVJSonvuuUenTp3SVVddpXXr1ik8PNx1zNKlSzV27Fhdf/31CgoK0u23366cnBzL3wsAAP6SazxHEgBQr4qKihQVFaUje9qaNilBQtfDKiws9PpeEgAAPEWuVaFHEgBgiUqnU5UmXLs04xwAAHjL7rnGZDsAAAAAAI/QIwkAsISvJiUAAKA+2D3XKCQBAJZwyKlKGwcuACCw2D3XGNoKAAAAAPAIPZIAAEvYfQgQACCw2D3X6JEEAAAAAHiEHkkAgCXsPk06ACCw2D3XKCQBAJZw/Gcx4zwAAPia3XONoa0AAAAAAI/QIwkAsESlSdOkm3EOAAC8Zfdco5AEAFii0lm1mHEeAAB8ze65xtBWAAAAAIBH6JEEAFjC7pMSAAACi91zjUISAGAJhwxVyjDlPAAA+Jrdc42hrQAAAAAAj9AjCQCwhMNZtZhxHgAAfM3uuUaPJAAAAADAI/RIAgAsUWnSvSRmnAMAAG/ZPdcoJAEAlrB74AIAAovdc42hrQAAAAAAj9AjCQCwhMNpyOE0YZp0E84BAIC37J5rFJIAAEvYfQgQACCw2D3XGNoKAAAAAPAIPZIAAEtUKkiVJly/rDShLQAAeMvuuUYhCQCwhNOke0mcDfReEgBAYLF7rjG0FQAAAADgEXokAQCWsPukBACAwGL3XKOQBABYotIZpEqnCfeSOE1oDAAAXrJ7rjG0FQAAAADgEXokAQCWcMiQw4Trlw410Eu3AICAYvdco0cSAAAAAOAReiQBAJaw+6QEAIDAYvdco5AEAFjCvEkJGuYQIABAYLF7rjG0FQAAAADgEXokAQCWqJqUwPvhO2acAwAAb9k91ygkAQCWcChIlTae3Q4AEFjsnmsMbQUAAAAAeIQeSQCAJew+KQEAILDYPdcoJAEAlnAoyNYPbgYABBa75xpDWwEAAAAAHqFHEgBgiUqnoUqnCQ9uNuEcAAB4y+65Ro8kAAAAAMAj9EgCACxRadI06ZUN9F4SAEBgsXuuUUgCACzhcAbJYcLsdo4GOrsdACCw2D3XGNoKAAAAAPAIPZIAAEvYfQgQACCw2D3XKCQBAJZwyJyZ6RzeNwUAAK/ZPdcY2goAAAAA8Ag9kgAASzgUJIcJ1y/NOAcAAN6ye65RSAIALFHpDFKlCbPbmXEOAAC8Zfdca5itBgAAAAD4DD2SAABLOGTIITMmJfD+HAAAeMvuuUYhCQCwhN2HAAEAAovdc61hthoAAAAA4DP0SAIALGHeg5u5BgoA8D2751rDbDUAAAAAwGfokbQRh8OhI0eOKCIiQobRMG/qBWAtp9Op06dPKyEhQUFB3l17dDgNOZwmTEpgwjkQGMg1AJ4i18xDIWkjR44cUWJioq+bAaABOnTokNq2bevVORwmDQFqqA9uhvnINQB1Ra55j0LSRiIiIiRJ33zaXpHNGuYPLOrPbV16+LoJ8EMVKtcmveX6/QH4E3IN50OuoSbkmnkoJG3k7LCfyGZBiowgcOEuxGjk6ybAHzmr/mPGsEGHM0gOE6Y4N+McCAzkGs6HXEONyDXTUEgCACxRKUOVJjx02YxzAADgLbvnWsMsfwEAAAAAPkOPJADAEnYfAgQACCx2zzUKSQCAJSplzvCdSu+bAgCA1+yeaw2z/AUAAAAA+Aw9kgAAS9h9CBAAILDYPdcaZqsBAAAAAD5DjyQAwBKVziBVmnDV1YxzAADgLbvnWsNsNQCgwXHKkMOExenBxAaVlZXKyMhQhw4d1LhxY1100UWaPn26nE7nj+1yOpWZman4+Hg1btxYKSkp+uqrr+rjIwAABBC75xqFJAAgYM2aNUvz58/Xs88+q7y8PM2aNUtPPfWU5s2b59rnqaeeUk5OjhYsWKCtW7eqadOmSk1NVWlpqQ9bDgBAdf6UawxtBQBYwhdDgDZv3qxbb71VN910kySpffv2+stf/qKPP/5YUtVV27lz5+rxxx/XrbfeKklasmSJYmNjtWrVKo0YMcLr9gIAApPdc40eSQCAJRxOw7RFkoqKityWsrKyaq/505/+VOvXr9fevXslSbt27dKmTZv0s5/9TJJ04MABFRQUKCUlxXVMVFSU+vXrpy1btljwqQAAGiq75xo9kgCABikxMdHt+ylTpmjq1Klu6x555BEVFRWpW7duCg4OVmVlpZ588kmlpaVJkgoKCiRJsbGxbsfFxsa6tgEAYIWGlmsUkgAAS1QqSJUmDIQ5e45Dhw4pMjLStT4sLKzavsuXL9fSpUv1yiuv6OKLL9bOnTs1fvx4JSQkKD093eu2AADsy+65RiEJALDEfw/f8fY8khQZGekWuDV56KGH9Mgjj7juCenRo4e++eYbZWdnKz09XXFxcZKkY8eOKT4+3nXcsWPH1KtXL6/bCgAIXHbPNe6RBAAErB9++EFBQe5RFxwcLIfDIUnq0KGD4uLitH79etf2oqIibd26Vf3797e0rQAAXIg/5Ro9kgAASzgUJIcJ1y89OcfNN9+sJ598Uu3atdPFF1+sHTt2aM6cObr77rslSYZhaPz48XriiSfUuXNndejQQRkZGUpISNDQoUO9bisAIHDZPdcoJAEAlqh0Gqo0YQiQJ+eYN2+eMjIydO+99+r48eNKSEjQb37zG2VmZrr2mTx5skpKSnTPPffo1KlTuuq
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x400 with 4 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"for index in range(0, len(optimized_metrics)):\n",
|
|||
|
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"No Diabetes\", \"Diabetes\"]\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Значение 100 означает количество верно классифицированных объектов, которые относятся к классу \"No Diabetes\". Можно сделать вывод, что модель отлично идентифицирует объекты этого класса.\n",
|
|||
|
"\n",
|
|||
|
"Значения 0 означают количество верно классифицированных объектов, которые относятся к классу \"Diabetes\". Можно сделать вывод, что это не высокая точность модели в определении объектов данного класса"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 31,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"(700, 8)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Pregnancies</th>\n",
|
|||
|
" <th>Glucose</th>\n",
|
|||
|
" <th>BloodPressure</th>\n",
|
|||
|
" <th>SkinThickness</th>\n",
|
|||
|
" <th>Insulin</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>DiabetesPedigreeFunction</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>6</td>\n",
|
|||
|
" <td>98</td>\n",
|
|||
|
" <td>58</td>\n",
|
|||
|
" <td>33</td>\n",
|
|||
|
" <td>190</td>\n",
|
|||
|
" <td>34.0</td>\n",
|
|||
|
" <td>0.430</td>\n",
|
|||
|
" <td>43</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>112</td>\n",
|
|||
|
" <td>75</td>\n",
|
|||
|
" <td>32</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>35.7</td>\n",
|
|||
|
" <td>0.148</td>\n",
|
|||
|
" <td>21</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>108</td>\n",
|
|||
|
" <td>64</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>30.8</td>\n",
|
|||
|
" <td>0.158</td>\n",
|
|||
|
" <td>21</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" <td>107</td>\n",
|
|||
|
" <td>80</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>24.6</td>\n",
|
|||
|
" <td>0.856</td>\n",
|
|||
|
" <td>34</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>7</td>\n",
|
|||
|
" <td>136</td>\n",
|
|||
|
" <td>90</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>29.9</td>\n",
|
|||
|
" <td>0.210</td>\n",
|
|||
|
" <td>50</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>695</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>105</td>\n",
|
|||
|
" <td>80</td>\n",
|
|||
|
" <td>45</td>\n",
|
|||
|
" <td>191</td>\n",
|
|||
|
" <td>33.7</td>\n",
|
|||
|
" <td>0.711</td>\n",
|
|||
|
" <td>29</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>696</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>126</td>\n",
|
|||
|
" <td>56</td>\n",
|
|||
|
" <td>29</td>\n",
|
|||
|
" <td>152</td>\n",
|
|||
|
" <td>28.7</td>\n",
|
|||
|
" <td>0.801</td>\n",
|
|||
|
" <td>21</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>697</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>95</td>\n",
|
|||
|
" <td>54</td>\n",
|
|||
|
" <td>14</td>\n",
|
|||
|
" <td>88</td>\n",
|
|||
|
" <td>26.1</td>\n",
|
|||
|
" <td>0.748</td>\n",
|
|||
|
" <td>22</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>698</th>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>100</td>\n",
|
|||
|
" <td>68</td>\n",
|
|||
|
" <td>23</td>\n",
|
|||
|
" <td>81</td>\n",
|
|||
|
" <td>31.6</td>\n",
|
|||
|
" <td>0.949</td>\n",
|
|||
|
" <td>28</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>699</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>85</td>\n",
|
|||
|
" <td>66</td>\n",
|
|||
|
" <td>29</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>26.6</td>\n",
|
|||
|
" <td>0.351</td>\n",
|
|||
|
" <td>31</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>700 rows × 8 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
|||
|
"0 6 98 58 33 190 34.0 \n",
|
|||
|
"1 2 112 75 32 0 35.7 \n",
|
|||
|
"2 2 108 64 0 0 30.8 \n",
|
|||
|
"3 8 107 80 0 0 24.6 \n",
|
|||
|
"4 7 136 90 0 0 29.9 \n",
|
|||
|
".. ... ... ... ... ... ... \n",
|
|||
|
"695 2 105 80 45 191 33.7 \n",
|
|||
|
"696 1 126 56 29 152 28.7 \n",
|
|||
|
"697 2 95 54 14 88 26.1 \n",
|
|||
|
"698 3 100 68 23 81 31.6 \n",
|
|||
|
"699 1 85 66 29 0 26.6 \n",
|
|||
|
"\n",
|
|||
|
" DiabetesPedigreeFunction Age \n",
|
|||
|
"0 0.430 43 \n",
|
|||
|
"1 0.148 21 \n",
|
|||
|
"2 0.158 21 \n",
|
|||
|
"3 0.856 34 \n",
|
|||
|
"4 0.210 50 \n",
|
|||
|
".. ... ... \n",
|
|||
|
"695 0.711 29 \n",
|
|||
|
"696 0.801 21 \n",
|
|||
|
"697 0.748 22 \n",
|
|||
|
"698 0.949 28 \n",
|
|||
|
"699 0.351 31 \n",
|
|||
|
"\n",
|
|||
|
"[700 rows x 8 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import seaborn as sns\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"\n",
|
|||
|
"df = pd.read_csv(\"data/diabetes.csv\")\n",
|
|||
|
"\n",
|
|||
|
"df = df.drop(columns=[\"Outcome\"])\n",
|
|||
|
"\n",
|
|||
|
"df = df.sample(n=700, random_state=random_state).reset_index(drop=True)\n",
|
|||
|
"\n",
|
|||
|
"print(df.shape) \n",
|
|||
|
"display(df)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 32,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
|||
|
"0 6 148 72 35 0 33.6 \n",
|
|||
|
"1 1 85 66 29 0 26.6 \n",
|
|||
|
"2 8 183 64 0 0 23.3 \n",
|
|||
|
"3 1 89 66 23 94 28.1 \n",
|
|||
|
"4 0 137 40 35 168 43.1 \n",
|
|||
|
"\n",
|
|||
|
" DiabetesPedigreeFunction Age diabetes_risk_index \n",
|
|||
|
"0 0.627 50 71.68 \n",
|
|||
|
"1 0.351 31 46.28 \n",
|
|||
|
"2 0.672 32 74.69 \n",
|
|||
|
"3 0.167 21 55.33 \n",
|
|||
|
"4 2.288 33 81.43 \n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"df = pd.read_csv(\"data/diabetes.csv\")\n",
|
|||
|
"\n",
|
|||
|
"required_columns = [\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]\n",
|
|||
|
"missing_columns = [col for col in required_columns if col not in df.columns]\n",
|
|||
|
"if missing_columns:\n",
|
|||
|
" raise ValueError(f\"Отсутствуют столбцы: {missing_columns}\")\n",
|
|||
|
"\n",
|
|||
|
"df[\"diabetes_risk_index\"] = (\n",
|
|||
|
" df[\"Glucose\"] * 0.3 \n",
|
|||
|
" + df[\"BMI\"] * 0.3 \n",
|
|||
|
" + df[\"Age\"] * 0.2 \n",
|
|||
|
" + df[\"BloodPressure\"] * 0.1 \n",
|
|||
|
" + df[\"Insulin\"] * 0.1 \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"print(df[[\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \"BMI\", \"DiabetesPedigreeFunction\", \"Age\", \"diabetes_risk_index\"]].head())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 33,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Pregnancies</th>\n",
|
|||
|
" <th>Glucose</th>\n",
|
|||
|
" <th>BloodPressure</th>\n",
|
|||
|
" <th>SkinThickness</th>\n",
|
|||
|
" <th>Insulin</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>DiabetesPedigreeFunction</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Outcome</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>60</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>84</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.304</td>\n",
|
|||
|
" <td>21</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>618</th>\n",
|
|||
|
" <td>9</td>\n",
|
|||
|
" <td>112</td>\n",
|
|||
|
" <td>82</td>\n",
|
|||
|
" <td>24</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>28.2</td>\n",
|
|||
|
" <td>1.282</td>\n",
|
|||
|
" <td>50</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>346</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>139</td>\n",
|
|||
|
" <td>46</td>\n",
|
|||
|
" <td>19</td>\n",
|
|||
|
" <td>83</td>\n",
|
|||
|
" <td>28.7</td>\n",
|
|||
|
" <td>0.654</td>\n",
|
|||
|
" <td>22</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>294</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>161</td>\n",
|
|||
|
" <td>50</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>21.9</td>\n",
|
|||
|
" <td>0.254</td>\n",
|
|||
|
" <td>65</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>231</th>\n",
|
|||
|
" <td>6</td>\n",
|
|||
|
" <td>134</td>\n",
|
|||
|
" <td>80</td>\n",
|
|||
|
" <td>37</td>\n",
|
|||
|
" <td>370</td>\n",
|
|||
|
" <td>46.2</td>\n",
|
|||
|
" <td>0.238</td>\n",
|
|||
|
" <td>46</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
|||
|
"60 2 84 0 0 0 0.0 \n",
|
|||
|
"618 9 112 82 24 0 28.2 \n",
|
|||
|
"346 1 139 46 19 83 28.7 \n",
|
|||
|
"294 0 161 50 0 0 21.9 \n",
|
|||
|
"231 6 134 80 37 370 46.2 \n",
|
|||
|
"\n",
|
|||
|
" DiabetesPedigreeFunction Age Outcome \n",
|
|||
|
"60 0.304 21 0 \n",
|
|||
|
"618 1.282 50 1 \n",
|
|||
|
"346 0.654 22 0 \n",
|
|||
|
"294 0.254 65 0 \n",
|
|||
|
"231 0.238 46 1 "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>diabetes_risk_index</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>60</th>\n",
|
|||
|
" <td>29.40</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>618</th>\n",
|
|||
|
" <td>60.26</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>346</th>\n",
|
|||
|
" <td>67.61</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>294</th>\n",
|
|||
|
" <td>72.87</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>231</th>\n",
|
|||
|
" <td>108.26</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" diabetes_risk_index\n",
|
|||
|
"60 29.40\n",
|
|||
|
"618 60.26\n",
|
|||
|
"346 67.61\n",
|
|||
|
"294 72.87\n",
|
|||
|
"231 108.26"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Pregnancies</th>\n",
|
|||
|
" <th>Glucose</th>\n",
|
|||
|
" <th>BloodPressure</th>\n",
|
|||
|
" <th>SkinThickness</th>\n",
|
|||
|
" <th>Insulin</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>DiabetesPedigreeFunction</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Outcome</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>668</th>\n",
|
|||
|
" <td>6</td>\n",
|
|||
|
" <td>98</td>\n",
|
|||
|
" <td>58</td>\n",
|
|||
|
" <td>33</td>\n",
|
|||
|
" <td>190</td>\n",
|
|||
|
" <td>34.0</td>\n",
|
|||
|
" <td>0.430</td>\n",
|
|||
|
" <td>43</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>324</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>112</td>\n",
|
|||
|
" <td>75</td>\n",
|
|||
|
" <td>32</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>35.7</td>\n",
|
|||
|
" <td>0.148</td>\n",
|
|||
|
" <td>21</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>624</th>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>108</td>\n",
|
|||
|
" <td>64</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>30.8</td>\n",
|
|||
|
" <td>0.158</td>\n",
|
|||
|
" <td>21</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>690</th>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" <td>107</td>\n",
|
|||
|
" <td>80</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>24.6</td>\n",
|
|||
|
" <td>0.856</td>\n",
|
|||
|
" <td>34</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>473</th>\n",
|
|||
|
" <td>7</td>\n",
|
|||
|
" <td>136</td>\n",
|
|||
|
" <td>90</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>29.9</td>\n",
|
|||
|
" <td>0.210</td>\n",
|
|||
|
" <td>50</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
|
|||
|
"668 6 98 58 33 190 34.0 \n",
|
|||
|
"324 2 112 75 32 0 35.7 \n",
|
|||
|
"624 2 108 64 0 0 30.8 \n",
|
|||
|
"690 8 107 80 0 0 24.6 \n",
|
|||
|
"473 7 136 90 0 0 29.9 \n",
|
|||
|
"\n",
|
|||
|
" DiabetesPedigreeFunction Age Outcome \n",
|
|||
|
"668 0.430 43 0 \n",
|
|||
|
"324 0.148 21 0 \n",
|
|||
|
"624 0.158 21 0 \n",
|
|||
|
"690 0.856 34 0 \n",
|
|||
|
"473 0.210 50 0 "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>diabetes_risk_index</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>668</th>\n",
|
|||
|
" <td>73.00</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>324</th>\n",
|
|||
|
" <td>56.01</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>624</th>\n",
|
|||
|
" <td>52.24</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>690</th>\n",
|
|||
|
" <td>54.28</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>473</th>\n",
|
|||
|
" <td>68.77</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" diabetes_risk_index\n",
|
|||
|
"668 73.00\n",
|
|||
|
"324 56.01\n",
|
|||
|
"624 52.24\n",
|
|||
|
"690 54.28\n",
|
|||
|
"473 68.77"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from typing import Tuple\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from pandas import DataFrame\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"def split_into_train_test(\n",
|
|||
|
" df_input: DataFrame,\n",
|
|||
|
" target_colname: str = \"diabetes_risk_index\",\n",
|
|||
|
" frac_train: float = 0.8,\n",
|
|||
|
" random_state: int = None,\n",
|
|||
|
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
|||
|
" if not (0 < frac_train < 1):\n",
|
|||
|
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
|
|||
|
" \n",
|
|||
|
" if target_colname not in df_input.columns:\n",
|
|||
|
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
|
|||
|
" \n",
|
|||
|
" X = df_input.drop(columns=[target_colname])\n",
|
|||
|
" y = df_input[[target_colname]]\n",
|
|||
|
"\n",
|
|||
|
" X_train, X_test, y_train, y_test = train_test_split(\n",
|
|||
|
" X, y,\n",
|
|||
|
" test_size=(1.0 - frac_train),\n",
|
|||
|
" random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" return X_train, X_test, y_train, y_test\n",
|
|||
|
"\n",
|
|||
|
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
|
|||
|
" df, \n",
|
|||
|
" target_colname=\"diabetes_risk_index\", \n",
|
|||
|
" frac_train=0.8, \n",
|
|||
|
" random_state=42\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_train\", X_train.head())\n",
|
|||
|
"display(\"y_train\", y_train.head())\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_test\", X_test.head())\n",
|
|||
|
"display(\"y_test\", y_test.head())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 34,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.pipeline import make_pipeline\n",
|
|||
|
"from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
|
|||
|
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
|
|||
|
"\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"\n",
|
|||
|
"models = {\n",
|
|||
|
" # Линейная регрессия\n",
|
|||
|
" \"linear\": {\n",
|
|||
|
" \"model\": linear_model.LinearRegression(n_jobs=-1)\n",
|
|||
|
" },\n",
|
|||
|
" # Полиномиальная регрессия степени 2\n",
|
|||
|
" \"linear_poly\": {\n",
|
|||
|
" \"model\": make_pipeline(\n",
|
|||
|
" PolynomialFeatures(degree=2),\n",
|
|||
|
" StandardScaler(),\n",
|
|||
|
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" # Полиномиальная регрессия с взаимодействиями\n",
|
|||
|
" \"linear_interact\": {\n",
|
|||
|
" \"model\": make_pipeline(\n",
|
|||
|
" PolynomialFeatures(interaction_only=True),\n",
|
|||
|
" StandardScaler(),\n",
|
|||
|
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" # Ridge-регрессия\n",
|
|||
|
" \"ridge\": {\n",
|
|||
|
" \"model\": make_pipeline(\n",
|
|||
|
" StandardScaler(),\n",
|
|||
|
" linear_model.RidgeCV()\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" # Регрессия на основе дерева решений\n",
|
|||
|
" \"decision_tree\": {\n",
|
|||
|
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
|
|||
|
" },\n",
|
|||
|
" # Метод ближайших соседей (kNN)\n",
|
|||
|
" \"knn\": {\n",
|
|||
|
" \"model\": make_pipeline(\n",
|
|||
|
" StandardScaler(),\n",
|
|||
|
" neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" # Случайный лес (Random Forest)\n",
|
|||
|
" \"random_forest\": {\n",
|
|||
|
" \"model\": ensemble.RandomForestRegressor(\n",
|
|||
|
" max_depth=7, random_state=random_state, n_jobs=-1\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" # Нейронная сеть (MLPRegressor)\n",
|
|||
|
" \"mlp\": {\n",
|
|||
|
" \"model\": make_pipeline(\n",
|
|||
|
" StandardScaler(),\n",
|
|||
|
" neural_network.MLPRegressor(\n",
|
|||
|
" activation=\"tanh\",\n",
|
|||
|
" hidden_layer_sizes=(3,),\n",
|
|||
|
" max_iter=500,\n",
|
|||
|
" early_stopping=True,\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" )\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
"}"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение и оценка моделей с помощью различных алгоритмов"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 35,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: linear\n",
|
|||
|
"Model: linear_poly\n",
|
|||
|
"Model: linear_interact\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"Model: knn\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tabee\\AIM_PIbd-31_Tabeev_A.P\\.venv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import math\n",
|
|||
|
"from pandas import DataFrame\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"\n",
|
|||
|
"for model_name in models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
"\n",
|
|||
|
" fitted_model = models[model_name][\"model\"].fit(\n",
|
|||
|
" X_train.values, y_train.values.ravel()\n",
|
|||
|
" )\n",
|
|||
|
" y_train_pred = fitted_model.predict(X_train.values)\n",
|
|||
|
" y_test_pred = fitted_model.predict(X_test.values)\n",
|
|||
|
" models[model_name][\"fitted\"] = fitted_model\n",
|
|||
|
" models[model_name][\"train_preds\"] = y_train_pred\n",
|
|||
|
" models[model_name][\"preds\"] = y_test_pred\n",
|
|||
|
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
|
|||
|
" metrics.mean_squared_error(y_train, y_train_pred)\n",
|
|||
|
" )\n",
|
|||
|
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
|
|||
|
" metrics.mean_squared_error(y_test, y_test_pred)\n",
|
|||
|
" )\n",
|
|||
|
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
|
|||
|
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
|
|||
|
" )\n",
|
|||
|
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Вывод результатов оценки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 36,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_4f71e_row0_col0, #T_4f71e_row0_col1, #T_4f71e_row1_col0, #T_4f71e_row1_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row0_col2, #T_4f71e_row1_col2, #T_4f71e_row7_col3 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row0_col3, #T_4f71e_row1_col3, #T_4f71e_row2_col3, #T_4f71e_row3_col3, #T_4f71e_row7_col2 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row2_col0 {\n",
|
|||
|
" background-color: #25838e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row2_col1 {\n",
|
|||
|
" background-color: #25858e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row2_col2 {\n",
|
|||
|
" background-color: #7100a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row3_col0 {\n",
|
|||
|
" background-color: #25848e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row3_col1 {\n",
|
|||
|
" background-color: #24878e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row3_col2 {\n",
|
|||
|
" background-color: #7501a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row4_col0 {\n",
|
|||
|
" background-color: #23888e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row4_col1 {\n",
|
|||
|
" background-color: #23898e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row4_col2 {\n",
|
|||
|
" background-color: #7a02a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row4_col3, #T_4f71e_row6_col2 {\n",
|
|||
|
" background-color: #d9586a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row5_col0, #T_4f71e_row5_col1 {\n",
|
|||
|
" background-color: #89d548;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row5_col2 {\n",
|
|||
|
" background-color: #d24f71;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row5_col3 {\n",
|
|||
|
" background-color: #7201a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row6_col0, #T_4f71e_row7_col0, #T_4f71e_row7_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row6_col1 {\n",
|
|||
|
" background-color: #a2da37;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4f71e_row6_col3 {\n",
|
|||
|
" background-color: #5302a3;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_4f71e\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_4f71e_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
|
|||
|
" <th id=\"T_4f71e_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
|
|||
|
" <th id=\"T_4f71e_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
|
|||
|
" <th id=\"T_4f71e_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4f71e_level0_row0\" class=\"row_heading level0 row0\" >linear</th>\n",
|
|||
|
" <td id=\"T_4f71e_row0_col0\" class=\"data row0 col0\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_4f71e_row0_col1\" class=\"data row0 col1\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_4f71e_row0_col2\" class=\"data row0 col2\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_4f71e_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4f71e_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
|
|||
|
" <td id=\"T_4f71e_row1_col0\" class=\"data row1 col0\" >0.002409</td>\n",
|
|||
|
" <td id=\"T_4f71e_row1_col1\" class=\"data row1 col1\" >0.002181</td>\n",
|
|||
|
" <td id=\"T_4f71e_row1_col2\" class=\"data row1 col2\" >0.040847</td>\n",
|
|||
|
" <td id=\"T_4f71e_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4f71e_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_4f71e_row2_col0\" class=\"data row2 col0\" >1.856691</td>\n",
|
|||
|
" <td id=\"T_4f71e_row2_col1\" class=\"data row2 col1\" >3.387390</td>\n",
|
|||
|
" <td id=\"T_4f71e_row2_col2\" class=\"data row2 col2\" >1.612410</td>\n",
|
|||
|
" <td id=\"T_4f71e_row2_col3\" class=\"data row2 col3\" >0.967338</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4f71e_level0_row3\" class=\"row_heading level0 row3\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_4f71e_row3_col0\" class=\"data row3 col0\" >2.409987</td>\n",
|
|||
|
" <td id=\"T_4f71e_row3_col1\" class=\"data row3 col1\" >4.281378</td>\n",
|
|||
|
" <td id=\"T_4f71e_row3_col2\" class=\"data row3 col2\" >1.847379</td>\n",
|
|||
|
" <td id=\"T_4f71e_row3_col3\" class=\"data row3 col3\" >0.947823</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4f71e_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
|
|||
|
" <td id=\"T_4f71e_row4_col0\" class=\"data row4 col0\" >5.170083</td>\n",
|
|||
|
" <td id=\"T_4f71e_row4_col1\" class=\"data row4 col1\" >5.824576</td>\n",
|
|||
|
" <td id=\"T_4f71e_row4_col2\" class=\"data row4 col2\" >2.073599</td>\n",
|
|||
|
" <td id=\"T_4f71e_row4_col3\" class=\"data row4 col3\" >0.903431</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4f71e_level0_row5\" class=\"row_heading level0 row5\" >mlp</th>\n",
|
|||
|
" <td id=\"T_4f71e_row5_col0\" class=\"data row5 col0\" >60.594447</td>\n",
|
|||
|
" <td id=\"T_4f71e_row5_col1\" class=\"data row5 col1\" >59.994001</td>\n",
|
|||
|
" <td id=\"T_4f71e_row5_col2\" class=\"data row5 col2\" >7.553087</td>\n",
|
|||
|
" <td id=\"T_4f71e_row5_col3\" class=\"data row5 col3\" >-9.245313</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4f71e_level0_row6\" class=\"row_heading level0 row6\" >linear_poly</th>\n",
|
|||
|
" <td id=\"T_4f71e_row6_col0\" class=\"data row6 col0\" >67.720820</td>\n",
|
|||
|
" <td id=\"T_4f71e_row6_col1\" class=\"data row6 col1\" >66.518485</td>\n",
|
|||
|
" <td id=\"T_4f71e_row6_col2\" class=\"data row6 col2\" >8.144355</td>\n",
|
|||
|
" <td id=\"T_4f71e_row6_col3\" class=\"data row6 col3\" >-11.594887</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4f71e_level0_row7\" class=\"row_heading level0 row7\" >linear_interact</th>\n",
|
|||
|
" <td id=\"T_4f71e_row7_col0\" class=\"data row7 col0\" >67.518306</td>\n",
|
|||
|
" <td id=\"T_4f71e_row7_col1\" class=\"data row7 col1\" >67.518306</td>\n",
|
|||
|
" <td id=\"T_4f71e_row7_col2\" class=\"data row7 col2\" >8.216952</td>\n",
|
|||
|
" <td id=\"T_4f71e_row7_col3\" class=\"data row7 col3\" >-11.976353</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x135c8320bc0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 36,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
|
|||
|
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
|
|||
|
"]\n",
|
|||
|
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
|
|||
|
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
|
|||
|
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок\n",
|
|||
|
"\n",
|
|||
|
"Получение лучшей модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 37,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'linear'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
|
|||
|
"\n",
|
|||
|
"display(best_model)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Подбор гиперпараметров методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 40,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Fitting 3 folds for each of 27 candidates, totalling 81 fits\n",
|
|||
|
"Лучшие параметры: {'max_depth': 20, 'min_samples_split': 5, 'n_estimators': 150}\n",
|
|||
|
"Лучший результат (MSE): 11.258082426680579\n",
|
|||
|
"Ошибка на тестовой выборке (MSE): 9.1015\n",
|
|||
|
"Коэффициент детерминации (R²): 0.9741\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"\n",
|
|||
|
"# Удаление пропущенных значений (если есть)\n",
|
|||
|
"df.dropna(inplace=True)\n",
|
|||
|
"\n",
|
|||
|
"X = df[[\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \n",
|
|||
|
" \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]] \n",
|
|||
|
"y = df[\"diabetes_risk_index\"] \n",
|
|||
|
"\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(\n",
|
|||
|
" X, y, test_size=0.2, random_state=42\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"model = RandomForestRegressor(random_state=42)\n",
|
|||
|
"\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" 'n_estimators': [50, 100, 150], \n",
|
|||
|
" 'max_depth': [10, 20, 30], \n",
|
|||
|
" 'min_samples_split': [5, 10, 15] \n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"grid_search = GridSearchCV(\n",
|
|||
|
" estimator=model,\n",
|
|||
|
" param_grid=param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', \n",
|
|||
|
" cv=3, \n",
|
|||
|
" n_jobs=-1, \n",
|
|||
|
" verbose=2 \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
|
|||
|
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)\n",
|
|||
|
"\n",
|
|||
|
"best_model = grid_search.best_estimator_\n",
|
|||
|
"y_pred = best_model.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
|
|||
|
"r2 = metrics.r2_score(y_test, y_pred)\n",
|
|||
|
"\n",
|
|||
|
"print(f\"Ошибка на тестовой выборке (MSE): {mse:.4f}\")\n",
|
|||
|
"print(f\"Коэффициент детерминации (R²): {r2:.4f}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение модели с новыми гиперпараметрами и сравнение новых и старых данных"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 41,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n",
|
|||
|
"Старые параметры: {'max_depth': 20, 'min_samples_split': 5, 'n_estimators': 100}\n",
|
|||
|
"Лучший результат (MSE) на старых параметрах: 11.682596479300793\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"Новые параметры: {'max_depth': 20, 'min_samples_split': 10, 'n_estimators': 100}\n",
|
|||
|
"Лучший результат (MSE) на новых параметрах: 18.55198575928597\n",
|
|||
|
"Среднеквадратическая ошибка (MSE) на тестовых данных: 10.739599657760765\n",
|
|||
|
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 3.27713284103052\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|||
|
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"old_param_grid = {\n",
|
|||
|
" 'n_estimators': [50, 100],\n",
|
|||
|
" 'max_depth': [ 10, 20],\n",
|
|||
|
" 'min_samples_split': [5, 10]\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
|
|||
|
" param_grid=old_param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
|
|||
|
"\n",
|
|||
|
"old_grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"old_best_params = old_grid_search.best_params_\n",
|
|||
|
"old_best_mse = -old_grid_search.best_score_\n",
|
|||
|
"\n",
|
|||
|
"new_param_grid = {\n",
|
|||
|
" 'n_estimators': [100],\n",
|
|||
|
" 'max_depth': [20],\n",
|
|||
|
" 'min_samples_split': [10]\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
|
|||
|
" param_grid=new_param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', cv=2)\n",
|
|||
|
"\n",
|
|||
|
"new_grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"new_best_params = new_grid_search.best_params_\n",
|
|||
|
"new_best_mse = -new_grid_search.best_score_\n",
|
|||
|
"\n",
|
|||
|
"model_best = RandomForestRegressor(**new_best_params)\n",
|
|||
|
"model_best.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"model_oldbest = RandomForestRegressor(**old_best_params)\n",
|
|||
|
"model_oldbest.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"y_pred = model_best.predict(X_test)\n",
|
|||
|
"y_oldpred = model_oldbest.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
|
|||
|
"rmse = np.sqrt(mse)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Старые параметры:\", old_best_params)\n",
|
|||
|
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
|
|||
|
"print(\"\\n\")\n",
|
|||
|
"print(\"Новые параметры:\", new_best_params)\n",
|
|||
|
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
|
|||
|
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
|
|||
|
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 45,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1wAAAHWCAYAAABjUYhTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzde1xUdf748dcMAzMwzAzCAIIwgTcYtEy0zEtpW5u23aBau7ibdr9T36ys3dKs7bLbZY1q29y1zLbdsi3YLr91t5tdXLNUKpUhrziAKAyXYRiYAWbm98fI5AAaIjhc3s/Hg4fOOWfO+Zwzt/P+XN4fhc/n8yGEEEIIIYQQotcpQ10AIYQQQgghhBisJOASQgghhBBCiD4iAZcQQgghhBBC9BEJuIQQQgghhBCij0jAJYQQQgghhBB9RAIuIYQQQgghhOgjEnAJIYQQQgghRB+RgEsIIYQQQggh+ogEXEIIIYQQQgjRRyTgEkIIIYQQQog+IgGXEEIchZUrV6JQKNi4cWOX62fNmsX48eOPc6mEEEII0V9JwCWEEEIIIYQQfUQCLiGEEEIIIYToIxJwCSFEH2tra+ORRx5h1KhRqNVq0tLS+M1vfoPb7Q7aLi0tDYVCgUKhQKlUMnz4cC677DKsVmtgm9LSUhQKBU899dRhj/fQQw+hUCg6Lf/b3/7GpEmTiIyMJDY2lssvv5yysrKfLH9X+2tsbGT48OEoFArWrl37k/uoqKjg2muvJTk5GbVaTXp6OjfffDMtLS2BbppH+lu5ciUA33//PQsWLGDkyJFoNBqGDx/ONddcQ01NTZdlLikpYe7cuej1euLi4rjjjjtwuVxB2yoUCm677bbDlr29fKWlpUHL//3vf3P66aej1WrR6XScd955bNu27SevxU+d70MPPRTYdu/evdxyyy1kZGQQGRlJXFwcv/zlLzuVpX2fn3/+OTfeeCNxcXHo9Xquuuoq6urqgrb917/+xXnnnRd4LUaNGsUjjzyCx+MJ2m7WrFkoFApycnI6ncONN96IQqHo1H3W6/WybNkyxo0bh0ajITExkRtvvDGoDIe+z7v6S0tLA4Lf63/84x854YQTiIyMZObMmWzdujXouAsWLCA6OvqI173jtRVCiONFFeoCCCHEQGS327HZbJ2Wt7a2dlp23XXX8eqrr3LppZeycOFCNmzYwOOPP47FYqGgoCBo29NPP50bbrgBr9fL1q1bWbZsGfv27eOLL744pvI++uijPPjgg8ydO5frrruO6upqnnvuOc444wyKioqIiYk5qv09/fTTHDhwoFvb7tu3j1NPPZX6+npuuOEGMjMzqaio4J///CdNTU2cccYZvPbaa0FlBfjtb38bWDZt2jQAPvzwQ3bv3s3VV1/N8OHD2bZtG8uXL2fbtm189dVXnQLDuXPnkpaWxuOPP85XX31Ffn4+dXV1rFq16qjOt6PXXnuN+fPnM3v2bH7/+9/T1NTEiy++yIwZMygqKgoEDUfy8MMPk56eHnjc2NjIzTffHLTNN998w//+9z8uv/xyUlJSKC0t5cUXX2TWrFkUFxcTFRUVtP1tt91GTEwMDz30ED/88AMvvvgie/fuZe3atYFrs3LlSqKjo7nrrruIjo7mk08+YfHixTQ0NPDkk08G7U+j0fDBBx9QVVVFQkICAM3Nzbz55ptoNJpO53TjjTeycuVKrr76avLy8tizZw/PP/88RUVFrFu3jvDwcJYtW0ZjYyMAFouFxx57jN/85jeYzWaAToHTqlWrcDgc3HrrrbhcLp599ll+9rOfsWXLFhITE3/yOgshRMj5hBBCdNsrr7ziA474N27cuMD23377rQ/wXXfddUH7ufvuu32A75NPPgksO+GEE3zz588P2u7KK6/0RUVFBR7v2bPHB/iefPLJw5ZxyZIlvkO/3ktLS31hYWG+Rx99NGi7LVu2+FQqVaflP7W/qqoqn06n85177rk+wPfpp58e8flXXXWVT6lU+r755ptO67xeb6dlM2fO9M2cObPLfTU1NXVa9o9//MMH+D7//PNOZb7wwguDtr3lllt8gO+7774LLAN8t95662HL3/6a79mzx+fz+XwOh8MXExPju/7664O2279/v89gMHRafrj9dbwe1dXVPsC3ZMmSI57v+vXrfYBv1apVnfY5adIkX0tLS2D5H/7wBx/g+9e//nXEfd54442+qKgon8vlCiybOXOmb9y4cb6TTjrJ99RTTwWWv/baa76UlBTf6aefHvRe/+KLL3yA7/XXXw/a95o1a7pc7vP5fJ9++ulh30Pt7/XIyEhfeXl5YPmGDRt8gO///u//Asvmz5/v02q1nfZxqI7XVgghjhfpUiiEED3wwgsv8OGHH3b6O+mkk4K2+3//7/8BcNdddwUtX7hwIQAffPBB0HK3243NZqOqqooPP/yQTz75hLPOOqvT8ZuamrDZbNTV1eHz+Y5Y1nfeeQev18vcuXOx2WyBv+HDhzNmzBg+/fTTozr3Rx55BIPBQF5e3k9u6/V6KSws5IILLmDy5Mmd1nfV9fFIIiMjA/93uVzYbDZOO+00ADZv3txp+1tvvTXo8e233w78+Lp03FdNTQ1er/eIZfjwww+pr6/niiuuCLqeYWFhTJky5aiv55Ecer6tra3U1NQwevRoYmJiujzfG264gfDw8MDjm2++GZVKFXS+h+7T4XBgs9k4/fTTaWpqoqSkpNM+r776al555ZXA41deeYX58+ejVAbfQrz11lsYDAZ+/vOfB12XSZMmER0d3ePrkpOTw4gRIwKPTz31VKZMmdLpNQQCx+zYbVQIIUJJuhQKIUQPnHrqqV0GEMOGDQvqarh3716USiWjR48O2m748OHExMSwd+/eoOVvvPEGb7zxRuDxKaecwl//+tdOx1myZAlLliwB/N2+fvazn7Fs2TLGjBnTadsdO3bg8/m6XAcE3aD/lD179vDSSy/x4osvdtmlrKPq6moaGhp6LVV+bW0tS5cu5Y033qCqqipond1u77R9x3MeNWoUSqWy0xioFStWsGLFCgAiIiKYMmUKzzzzTJev8Y4dOwD42c9+1mUZ9Xp9t8/npzQ3N/P444/zyiuvUFFRERRcd+d8o6OjSUpKCjrfbdu28cADD/DJJ5/Q0NAQtH1X+5w3bx733nsvX3/9NQkJCaxdu5aXXnqJL7/8Mmi7HTt2YLfbA10PO+r4enVXV+/bsWPHsnr16qBlTqeT+Pj4wOPU1FQWLlzIHXfc0aPjCiFEb5GASwghjoPutuScc8453HPPPQCUl5fz+9//njPPPJONGzcGtUzccMMN/PKXv8Tj8WCxWHjooYfIycnpMmmD1+tFoVDw73//m7CwsE7rfyrZwKF++9vfMmbMGObPn3/M48p6Yu7cufzvf//jnnvu4eSTTyY6Ohqv18ucOXN+smUKDv86XHTRRdx22234fD727NnDww8/zPnnnx8Irg7VfpzXXnuN4cOHd1qvUvXeT+vtt9/OK6+8wp133snUqVMxGAwoFAouv/zybp1vR/X19cycORO9Xs/DDz/MqFGj0Gg0bN68mUWLFnW5z/j4eC644AJeeeUVEhMTmT59eqcKBPBfl4SEBF5//fUuj31oMNQXNBoN7733HuBvuXv55Ze58847SUpKYu7cuX16bCGEOBIJuIQQog+dcMIJeL1eduzYEUgKAHDgwAHq6+s54YQTgrZPSkri7LPPDjzOyMhg2rRpFBYWcsUVVwSWjxkzJrDd7NmzaWpq4re//W1QRsN2o0aNwufzkZ6eztixY3t8LkVFRbzxxhsUFhZ2Gbh1JT4+Hr1e3ymrXE/U1dXx8ccfs3TpUhYvXhxY3lVQdOi6QxNT7Ny5E6/X2ympRUpKStB1j46OZt68eRQVFXXa56hRowBISEgIek5f+Oc//8n8+fN5+umnA8tcLhf19fVdbr9jxw7OPPPMwOPGxkYqKyv5xS9+AcDatWupqanhnXfe4Ywzzghst2fPniOW45prrmHevHkYDIbDZvobNWoUH330EdOnTw+qHDhWXb2+27dv7/QahoWFBb0e5513HrGxsaxZs0Y
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x500 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"plt.figure(figsize=(10, 5))\n",
|
|||
|
"plt.scatter(range(len(y_test)), y_test, label=\"Актуальные значения\", color=\"green\", alpha=0.5)\n",
|
|||
|
"plt.scatter(range(len(y_test)), y_pred, label=\"Новые параметры\", color=\"blue\", alpha=0.5)\n",
|
|||
|
"plt.scatter(range(len(y_test)), y_oldpred, label=\"Старые параметры\", color=\"red\", alpha=0.5)\n",
|
|||
|
"plt.xlabel(\"Выборка\")\n",
|
|||
|
"plt.ylabel(\"Значения\")\n",
|
|||
|
"plt.legend()\n",
|
|||
|
"plt.title(\"Новые и старые параметры\")\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": ".venv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.12.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|