AIM-PIbd-31-Afanasev-S-S/lab_4/lab4.ipynb

3903 lines
386 KiB
Plaintext
Raw Normal View History

2024-11-16 00:44:40 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Начало лабораторной работы"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Вариант 3:* Диабет у индейцев Пима "
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Pregnancies', 'Glucose', 'BloodPressure', 'SkinThickness', 'Insulin',\n",
" 'BMI', 'DiabetesPedigreeFunction', 'Age', 'Outcome'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6</td>\n",
" <td>148</td>\n",
" <td>72</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" <td>33.6</td>\n",
" <td>0.627</td>\n",
" <td>50</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>85</td>\n",
" <td>66</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>26.6</td>\n",
" <td>0.351</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>8</td>\n",
" <td>183</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>23.3</td>\n",
" <td>0.672</td>\n",
" <td>32</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>89</td>\n",
" <td>66</td>\n",
" <td>23</td>\n",
" <td>94</td>\n",
" <td>28.1</td>\n",
" <td>0.167</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>137</td>\n",
" <td>40</td>\n",
" <td>35</td>\n",
" <td>168</td>\n",
" <td>43.1</td>\n",
" <td>2.288</td>\n",
" <td>33</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>763</th>\n",
" <td>10</td>\n",
" <td>101</td>\n",
" <td>76</td>\n",
" <td>48</td>\n",
" <td>180</td>\n",
" <td>32.9</td>\n",
" <td>0.171</td>\n",
" <td>63</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>764</th>\n",
" <td>2</td>\n",
" <td>122</td>\n",
" <td>70</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" <td>36.8</td>\n",
" <td>0.340</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>765</th>\n",
" <td>5</td>\n",
" <td>121</td>\n",
" <td>72</td>\n",
" <td>23</td>\n",
" <td>112</td>\n",
" <td>26.2</td>\n",
" <td>0.245</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>766</th>\n",
" <td>1</td>\n",
" <td>126</td>\n",
" <td>60</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.1</td>\n",
" <td>0.349</td>\n",
" <td>47</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>767</th>\n",
" <td>1</td>\n",
" <td>93</td>\n",
" <td>70</td>\n",
" <td>31</td>\n",
" <td>0</td>\n",
" <td>30.4</td>\n",
" <td>0.315</td>\n",
" <td>23</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>768 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 148 72 35 0 33.6 \n",
"1 1 85 66 29 0 26.6 \n",
"2 8 183 64 0 0 23.3 \n",
"3 1 89 66 23 94 28.1 \n",
"4 0 137 40 35 168 43.1 \n",
".. ... ... ... ... ... ... \n",
"763 10 101 76 48 180 32.9 \n",
"764 2 122 70 27 0 36.8 \n",
"765 5 121 72 23 112 26.2 \n",
"766 1 126 60 0 0 30.1 \n",
"767 1 93 70 31 0 30.4 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"0 0.627 50 1 \n",
"1 0.351 31 0 \n",
"2 0.672 32 1 \n",
"3 0.167 21 0 \n",
"4 2.288 33 1 \n",
".. ... ... ... \n",
"763 0.171 63 0 \n",
"764 0.340 27 0 \n",
"765 0.245 30 0 \n",
"766 0.349 47 1 \n",
"767 0.315 23 0 \n",
"\n",
"[768 rows x 9 columns]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"# Установим параметры для вывода\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"random_state = 42\n",
"\n",
"# Подключим датафрейм и выгрузим данные\n",
"df = pd.read_csv(\"C:/Users/TIGR228/Desktop/МИИ/Lab1/AIM-PIbd-31-Afanasev-S-S/static/csv/diabetes.csv\")\n",
"print(df.columns)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Бизнес-цели:\n",
"\n",
"1. Прогнозирование риска развития диабета\n",
"\n",
"Описание: Классифицировать пациентов на основе их медицинских данных для определения риска развития диабета (используя целевой признак \"Outcome\"). Эта задача актуальна для раннего выявления диабета и разработки профилактических мер, направленных на улучшение здоровья населения.\n",
"\n",
"2. Оценка факторов, влияющих на развитие диабета\n",
"\n",
"Описание: Предсказать вероятность развития диабета у новых пациентов на основе их медицинских характеристик (таких как уровень глюкозы, артериальное давление, индекс массы тела и другие параметры). Это позволит медицинским специалистам планировать лечение и наблюдение в зависимости от индивидуальных рисков пациентов. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Определение достижимого уровня качества модели для первой задачи "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n",
"Целевой признак -- Outcome"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>353</th>\n",
" <td>1</td>\n",
" <td>90</td>\n",
" <td>62</td>\n",
" <td>12</td>\n",
" <td>43</td>\n",
" <td>27.2</td>\n",
" <td>0.580</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>711</th>\n",
" <td>5</td>\n",
" <td>126</td>\n",
" <td>78</td>\n",
" <td>27</td>\n",
" <td>22</td>\n",
" <td>29.6</td>\n",
" <td>0.439</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>373</th>\n",
" <td>2</td>\n",
" <td>105</td>\n",
" <td>58</td>\n",
" <td>40</td>\n",
" <td>94</td>\n",
" <td>34.9</td>\n",
" <td>0.225</td>\n",
" <td>25</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46</th>\n",
" <td>1</td>\n",
" <td>146</td>\n",
" <td>56</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29.7</td>\n",
" <td>0.564</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>682</th>\n",
" <td>0</td>\n",
" <td>95</td>\n",
" <td>64</td>\n",
" <td>39</td>\n",
" <td>105</td>\n",
" <td>44.6</td>\n",
" <td>0.366</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>451</th>\n",
" <td>2</td>\n",
" <td>134</td>\n",
" <td>70</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>28.9</td>\n",
" <td>0.542</td>\n",
" <td>23</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>113</th>\n",
" <td>4</td>\n",
" <td>76</td>\n",
" <td>62</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>34.0</td>\n",
" <td>0.391</td>\n",
" <td>25</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>556</th>\n",
" <td>1</td>\n",
" <td>97</td>\n",
" <td>70</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" <td>38.1</td>\n",
" <td>0.218</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>667</th>\n",
" <td>10</td>\n",
" <td>111</td>\n",
" <td>70</td>\n",
" <td>27</td>\n",
" <td>0</td>\n",
" <td>27.5</td>\n",
" <td>0.141</td>\n",
" <td>40</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>107</th>\n",
" <td>4</td>\n",
" <td>144</td>\n",
" <td>58</td>\n",
" <td>28</td>\n",
" <td>140</td>\n",
" <td>29.5</td>\n",
" <td>0.287</td>\n",
" <td>37</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"353 1 90 62 12 43 27.2 \n",
"711 5 126 78 27 22 29.6 \n",
"373 2 105 58 40 94 34.9 \n",
"46 1 146 56 0 0 29.7 \n",
"682 0 95 64 39 105 44.6 \n",
".. ... ... ... ... ... ... \n",
"451 2 134 70 0 0 28.9 \n",
"113 4 76 62 0 0 34.0 \n",
"556 1 97 70 40 0 38.1 \n",
"667 10 111 70 27 0 27.5 \n",
"107 4 144 58 28 140 29.5 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"353 0.580 24 0 \n",
"711 0.439 40 0 \n",
"373 0.225 25 0 \n",
"46 0.564 29 0 \n",
"682 0.366 22 0 \n",
".. ... ... ... \n",
"451 0.542 23 1 \n",
"113 0.391 25 0 \n",
"556 0.218 30 0 \n",
"667 0.141 40 1 \n",
"107 0.287 37 0 \n",
"\n",
"[614 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>353</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>711</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>373</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>682</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>451</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>113</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>556</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>667</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>107</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"353 0\n",
"711 0\n",
"373 0\n",
"46 0\n",
"682 0\n",
".. ...\n",
"451 1\n",
"113 0\n",
"556 0\n",
"667 1\n",
"107 0\n",
"\n",
"[614 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>44</th>\n",
" <td>7</td>\n",
" <td>159</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.294</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>672</th>\n",
" <td>10</td>\n",
" <td>68</td>\n",
" <td>106</td>\n",
" <td>23</td>\n",
" <td>49</td>\n",
" <td>35.5</td>\n",
" <td>0.285</td>\n",
" <td>47</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>700</th>\n",
" <td>2</td>\n",
" <td>122</td>\n",
" <td>76</td>\n",
" <td>27</td>\n",
" <td>200</td>\n",
" <td>35.9</td>\n",
" <td>0.483</td>\n",
" <td>26</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>630</th>\n",
" <td>7</td>\n",
" <td>114</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>27.4</td>\n",
" <td>0.732</td>\n",
" <td>34</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>81</th>\n",
" <td>2</td>\n",
" <td>74</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.102</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>3</td>\n",
" <td>88</td>\n",
" <td>58</td>\n",
" <td>11</td>\n",
" <td>54</td>\n",
" <td>24.8</td>\n",
" <td>0.267</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>637</th>\n",
" <td>2</td>\n",
" <td>94</td>\n",
" <td>76</td>\n",
" <td>18</td>\n",
" <td>66</td>\n",
" <td>31.6</td>\n",
" <td>0.649</td>\n",
" <td>23</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>593</th>\n",
" <td>2</td>\n",
" <td>82</td>\n",
" <td>52</td>\n",
" <td>22</td>\n",
" <td>115</td>\n",
" <td>28.5</td>\n",
" <td>1.699</td>\n",
" <td>25</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>425</th>\n",
" <td>4</td>\n",
" <td>184</td>\n",
" <td>78</td>\n",
" <td>39</td>\n",
" <td>277</td>\n",
" <td>37.0</td>\n",
" <td>0.264</td>\n",
" <td>31</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>273</th>\n",
" <td>1</td>\n",
" <td>71</td>\n",
" <td>78</td>\n",
" <td>50</td>\n",
" <td>45</td>\n",
" <td>33.2</td>\n",
" <td>0.422</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"44 7 159 64 0 0 27.4 \n",
"672 10 68 106 23 49 35.5 \n",
"700 2 122 76 27 200 35.9 \n",
"630 7 114 64 0 0 27.4 \n",
"81 2 74 0 0 0 0.0 \n",
".. ... ... ... ... ... ... \n",
"32 3 88 58 11 54 24.8 \n",
"637 2 94 76 18 66 31.6 \n",
"593 2 82 52 22 115 28.5 \n",
"425 4 184 78 39 277 37.0 \n",
"273 1 71 78 50 45 33.2 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"44 0.294 40 0 \n",
"672 0.285 47 0 \n",
"700 0.483 26 0 \n",
"630 0.732 34 1 \n",
"81 0.102 22 0 \n",
".. ... ... ... \n",
"32 0.267 22 0 \n",
"637 0.649 23 0 \n",
"593 1.699 25 0 \n",
"425 0.264 31 1 \n",
"273 0.422 21 0 \n",
"\n",
"[154 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>44</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>672</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>700</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>630</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>81</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>637</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>593</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>425</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>273</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>154 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Outcome\n",
"44 0\n",
"672 0\n",
"700 0\n",
"630 1\n",
"81 0\n",
".. ...\n",
"32 0\n",
"637 0\n",
"593 0\n",
"425 1\n",
"273 0\n",
"\n",
"[154 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Устанавливаем случайное состояние\n",
"random_state = 42\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"Outcome\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование конвейера для классификации данных\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"features_engineering -- трансформер для конструирования признаков\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Построение конвейеров предобработки\n",
"\n",
"class DiabetesFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
"\n",
" def fit(self, X, y=None):\n",
" return self\n",
"\n",
" def transform(self, X, y=None):\n",
" # Создание новых признаков\n",
" X = X.copy()\n",
" X[\"BMI_to_Age_ratio\"] = X[\"BMI\"] / X[\"Age\"]\n",
" return X\n",
"\n",
" def get_feature_names_out(self, features_in):\n",
" # Добавление имен новых признаков\n",
" new_features = [\"BMI_to_Age_ratio\"]\n",
" return np.append(features_in, new_features, axis=0)\n",
"\n",
"# Обработка числовых данных. Числовой конвейр: заполнение пропущенных значений медианой и стандартизация\n",
"preprocessing_num_class = Pipeline(steps=[\n",
" ('imputer', SimpleImputer(strategy='median')),\n",
" ('scaler', StandardScaler())\n",
"])\n",
"\n",
"preprocessing_cat_class = Pipeline(steps=[\n",
" ('imputer', SimpleImputer(strategy='most_frequent')),\n",
" ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop='first'))\n",
"])\n",
"\n",
"columns_to_drop = []\n",
"numeric_columns = [\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\",\n",
" \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]\n",
"cat_columns = [\"Outcome\"]\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_num\", preprocessing_num_class, numeric_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat_class, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" ('preprocessing_cat', preprocessing_cat_class, [\"Outcome\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"custom_features\", DiabetesFeatures()),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Демонстрация работы конвейера"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome_1</th>\n",
" <th>BMI_to_Age_ratio</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>353</th>\n",
" <td>-0.851355</td>\n",
" <td>-0.980131</td>\n",
" <td>-0.404784</td>\n",
" <td>-0.553973</td>\n",
" <td>-0.331319</td>\n",
" <td>-0.607678</td>\n",
" <td>0.310794</td>\n",
" <td>-0.792169</td>\n",
" <td>0.0</td>\n",
" <td>0.767107</td>\n",
" </tr>\n",
" <tr>\n",
" <th>711</th>\n",
" <td>0.356576</td>\n",
" <td>0.161444</td>\n",
" <td>0.465368</td>\n",
" <td>0.392787</td>\n",
" <td>-0.526398</td>\n",
" <td>-0.302139</td>\n",
" <td>-0.116439</td>\n",
" <td>0.561034</td>\n",
" <td>0.0</td>\n",
" <td>-0.538540</td>\n",
" </tr>\n",
" <tr>\n",
" <th>373</th>\n",
" <td>-0.549372</td>\n",
" <td>-0.504474</td>\n",
" <td>-0.622322</td>\n",
" <td>1.213312</td>\n",
" <td>0.142444</td>\n",
" <td>0.372594</td>\n",
" <td>-0.764862</td>\n",
" <td>-0.707594</td>\n",
" <td>0.0</td>\n",
" <td>-0.526564</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46</th>\n",
" <td>-0.851355</td>\n",
" <td>0.795653</td>\n",
" <td>-0.731091</td>\n",
" <td>-1.311380</td>\n",
" <td>-0.730766</td>\n",
" <td>-0.289408</td>\n",
" <td>0.262314</td>\n",
" <td>-0.369293</td>\n",
" <td>0.0</td>\n",
" <td>0.783681</td>\n",
" </tr>\n",
" <tr>\n",
" <th>682</th>\n",
" <td>-1.153338</td>\n",
" <td>-0.821579</td>\n",
" <td>-0.296015</td>\n",
" <td>1.150195</td>\n",
" <td>0.244628</td>\n",
" <td>1.607482</td>\n",
" <td>-0.337630</td>\n",
" <td>-0.961320</td>\n",
" <td>0.0</td>\n",
" <td>-1.672162</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>451</th>\n",
" <td>-0.549372</td>\n",
" <td>0.415128</td>\n",
" <td>0.030292</td>\n",
" <td>-1.311380</td>\n",
" <td>-0.730766</td>\n",
" <td>-0.391255</td>\n",
" <td>0.195653</td>\n",
" <td>-0.876744</td>\n",
" <td>1.0</td>\n",
" <td>0.446259</td>\n",
" </tr>\n",
" <tr>\n",
" <th>113</th>\n",
" <td>0.054593</td>\n",
" <td>-1.424076</td>\n",
" <td>-0.404784</td>\n",
" <td>-1.311380</td>\n",
" <td>-0.730766</td>\n",
" <td>0.258017</td>\n",
" <td>-0.261879</td>\n",
" <td>-0.707594</td>\n",
" <td>0.0</td>\n",
" <td>-0.364639</td>\n",
" </tr>\n",
" <tr>\n",
" <th>556</th>\n",
" <td>-0.851355</td>\n",
" <td>-0.758158</td>\n",
" <td>0.030292</td>\n",
" <td>1.213312</td>\n",
" <td>-0.730766</td>\n",
" <td>0.779980</td>\n",
" <td>-0.786072</td>\n",
" <td>-0.284718</td>\n",
" <td>0.0</td>\n",
" <td>-2.739481</td>\n",
" </tr>\n",
" <tr>\n",
" <th>667</th>\n",
" <td>1.866489</td>\n",
" <td>-0.314212</td>\n",
" <td>0.030292</td>\n",
" <td>0.392787</td>\n",
" <td>-0.730766</td>\n",
" <td>-0.569486</td>\n",
" <td>-1.019383</td>\n",
" <td>0.561034</td>\n",
" <td>1.0</td>\n",
" <td>-1.015065</td>\n",
" </tr>\n",
" <tr>\n",
" <th>107</th>\n",
" <td>0.054593</td>\n",
" <td>0.732232</td>\n",
" <td>-0.622322</td>\n",
" <td>0.455904</td>\n",
" <td>0.569759</td>\n",
" <td>-0.314870</td>\n",
" <td>-0.577001</td>\n",
" <td>0.307308</td>\n",
" <td>0.0</td>\n",
" <td>-1.024606</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>614 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"353 -0.851355 -0.980131 -0.404784 -0.553973 -0.331319 -0.607678 \n",
"711 0.356576 0.161444 0.465368 0.392787 -0.526398 -0.302139 \n",
"373 -0.549372 -0.504474 -0.622322 1.213312 0.142444 0.372594 \n",
"46 -0.851355 0.795653 -0.731091 -1.311380 -0.730766 -0.289408 \n",
"682 -1.153338 -0.821579 -0.296015 1.150195 0.244628 1.607482 \n",
".. ... ... ... ... ... ... \n",
"451 -0.549372 0.415128 0.030292 -1.311380 -0.730766 -0.391255 \n",
"113 0.054593 -1.424076 -0.404784 -1.311380 -0.730766 0.258017 \n",
"556 -0.851355 -0.758158 0.030292 1.213312 -0.730766 0.779980 \n",
"667 1.866489 -0.314212 0.030292 0.392787 -0.730766 -0.569486 \n",
"107 0.054593 0.732232 -0.622322 0.455904 0.569759 -0.314870 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome_1 BMI_to_Age_ratio \n",
"353 0.310794 -0.792169 0.0 0.767107 \n",
"711 -0.116439 0.561034 0.0 -0.538540 \n",
"373 -0.764862 -0.707594 0.0 -0.526564 \n",
"46 0.262314 -0.369293 0.0 0.783681 \n",
"682 -0.337630 -0.961320 0.0 -1.672162 \n",
".. ... ... ... ... \n",
"451 0.195653 -0.876744 1.0 0.446259 \n",
"113 -0.261879 -0.707594 0.0 -0.364639 \n",
"556 -0.786072 -0.284718 0.0 -2.739481 \n",
"667 -1.019383 0.561034 1.0 -1.015065 \n",
"107 -0.577001 0.307308 0.0 -1.024606 \n",
"\n",
"[614 rows x 10 columns]"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для классификации\n",
" logistic -- логистическая регрессия\n",
"\n",
"ridge -- гребневая регрессия\n",
"\n",
"decision_tree -- дерево решений\n",
"\n",
"knn -- k-ближайших соседей\n",
"\n",
"naive_bayes -- наивный Байесовский классификатор\n",
"\n",
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"\n",
"random_forest -- метод случайного леса (набор деревьев решений)\n",
"\n",
"mlp -- многослойный персептрон (нейронная сеть)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"# Определите random_state для воспроизводимости результатов\n",
"random_state = 42\n",
"\n",
"# Определите модели машинного обучения для классификации\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression(random_state=random_state)},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\", random_state=random_state)},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210, random_state=random_state)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Сводная таблица оценок качества для использованных моделей классификации"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA3EAAAQ9CAYAAAD3ScTVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVwU9f8H8NcAwiI3iiwoIt544pWRZ4ZhpWna1yz6hebR5Z1nphze9jUJz7IU7auZlZqWWh55k+WZJt4omKAmAoJy7X5+f5CbKyzu6gA7w+v5fczj637mM7Of2YSX79nPzEhCCAEiIiIiIiJSBJvyHgARERERERGZj0UcERERERGRgrCIIyIiIiIiUhAWcURERERERArCIo6IiIiIiEhBWMQREREREREpCIs4IiIiIiIiBWERR0REREREpCAs4oiIiIiIiBSERRxVeHFxcZAkCZcuXSqV/V+6dAmSJCEuLk6W/e3atQuSJGHXrl2y7I+IiEgtIiMjIUmSWX0lSUJkZGTpDoiolLCII7JSixYtkq3wIyIiIiL1sCvvARCpnb+/P+7evYtKlSpZtN2iRYtQtWpV9O/f36i9Y8eOuHv3Luzt7WUcJRERkfJ9+OGHmDBhQnkPg6jUsYgjKmWSJEGj0ci2PxsbG1n3R0REpAbZ2dlwcnKCnR3/eUvqx+mURMVYtGgRGjduDAcHB/j6+uK9995Denp6kX4LFy5E7dq14ejoiCeeeAJ79+5F586d0blzZ0Of4q6JS01NxYABA1CjRg04ODjAx8cHPXv2NFyXV6tWLfz555/YvXs3JEmCJEmGfZq6Ju7gwYN4/vnn4eHhAScnJzRr1gyffPKJvB8MERGRFbh37dupU6fw2muvwcPDA+3bty/2mrjc3FyMGjUKXl5ecHFxwYsvvogrV64Uu99du3ahdevW0Gg0qFOnDj799FOT19n973//Q6tWreDo6AhPT0/069cPycnJpXK8RA/iqQqiB0RGRiIqKgohISF45513cObMGSxevBi///479u/fb5gWuXjxYgwdOhQdOnTAqFGjcOnSJfTq1QseHh6oUaNGie/Rp08f/Pnnnxg2bBhq1aqF69evY9u2bUhKSkKtWrUQExODYcOGwdnZGZMmTQIAeHt7m9zftm3b0L17d/j4+GDEiBHQarVISEjADz/8gBEjRsj34RAREVmR//znP6hXrx5mzJgBIQSuX79epM+gQYPwv//9D6+99hqeeuop7Ny5Ey+88EKRfkePHkW3bt3g4+ODqKgo6HQ6REdHw8vLq0jf6dOnY/Lkyejbty8GDRqEGzduYP78+ejYsSOOHj0Kd3f30jhcon8Jogpu+fLlAoBITEwU169fF/b29uLZZ58VOp3O0GfBggUCgFi2bJkQQojc3FxRpUoV0aZNG5Gfn2/oFxcXJwCITp06GdoSExMFALF8+XIhhBC3bt0SAMRHH31U4rgaN25stJ97fvnlFwFA/PLLL0IIIQoKCkRAQIDw9/cXt27dMuqr1+vN/yCIiIgUIiIiQgAQr776arHt9xw7dkwAEO+++65Rv9dee00AEBEREYa2Hj16iMqVK4u//vrL0Hbu3DlhZ2dntM9Lly4JW1tbMX36dKN9njhxQtjZ2RVpJyoNnE5JdJ/t27cjLy8PI0eOhI3Nvz8egwcPhqurK3788UcAwKFDh3Dz5k0MHjzYaO59WFgYPDw8SnwPR0dH2NvbY9euXbh169Zjj/no0aNITEzEyJEji5z5M/c2y0REREr09ttvl7h+8+bNAIDhw4cbtY8cOdLotU6nw/bt29GrVy/4+voa2uvWrYvnnnvOqO+6deug1+vRt29f/P3334ZFq9WiXr16+OWXXx7jiIjMw+mURPe5fPkyAKBBgwZG7fb29qhdu7Zh/b3/r1u3rlE/Ozs71KpVq8T3cHBwwOzZs/H+++/D29sbTz75JLp374433ngDWq3W4jFfuHABANCkSROLtyUiIlKygICAEtdfvnwZNjY2qFOnjlH7gzl//fp13L17t0iuA0Wz/ty5cxBCoF69esW+p6V3oyZ6FCziiMrByJEj0aNHD2zYsAE//fQTJk+ejJkzZ2Lnzp1o0aJFeQ+PiIhIERwdHcv8PfV6PSRJwpYtW2Bra1tkvbOzc5mPiSoeTqckuo+/vz8A4MyZM0bteXl5SExMNKy/9//nz5836ldQUGC4w+TD1KlTB++//z5+/vlnnDx5Enl5eZg7d65hvblTIe+dXTx58qRZ/YmIiCoKf39/6PV6w6yVex7M+WrVqkGj0RTJdaBo1tepUwdCCAQEBCAkJKTI8uSTT8p/IEQPYBFHdJ+QkBDY29sjNjYWQghD+xdffIGMjAzD3axat26NKlWqYOnSpSgoKDD0W7Vq1UOvc7tz5w5ycnKM2urUqQMXFxfk5uYa2pycnIp9rMGDWrZsiYCAAMTExBTpf/8xEBERVTT3rmeLjY01ao+JiTF6bWtri5CQEGzYsAFXr141tJ8/fx5btmwx6tu7d2/Y2toiKiqqSM4KIXDz5k0Zj4CoeJxOSXQfLy8vTJw4EVFRUejWrRtefPFFnDlzBosWLUKbNm3w+uuvAyi8Ri4yMhLDhg1Dly5d0LdvX1y6dAlxcXGoU6dOid+inT17Fs888wz69u2LRo0awc7ODuvXr8e1a9fQr18/Q79WrVph8eLFmDZtGurWrYtq1aqhS5cuRfZnY2ODxYsXo0ePHggKCsKAAQPg4+OD06dP488//8RPP/0k/wdFRESkAEFBQXj11VexaNEiZGRk4KmnnsKOHTuK/cYtMjISP//8M9q1a4d33nkHOp0OCxYsQJMmTXDs2DFDvzp16mDatGmYOHGi4fFCLi4uSExMxPr16zFkyBCMGTOmDI+SKiIWcUQPiIyMhJeXFxYsWIBRo0bB09MTQ4YMwYwZM4wuVh46dCiEEJg7dy7GjBmD5s2bY+PGjRg+fDg0Go3J/fv5+eHVV1/Fjh078OWXX8LOzg4NGzbE2rVr0adPH0O/KVOm4PLly5gzZw5u376NTp06FVvEAUBoaCh++eUXREVFYe7cudDr9ahTpw4GDx4s3wdDRESkQMuWLYOXlxdWrVqFDRs2oEuXLvjxxx/h5+dn1K9Vq1bYsmULxowZg8mTJ8PPzw/R0dFISEjA6dOnjfpOmDAB9evXx7x58xAVFQWgMN+fffZZvPjii2V2bFRxSYLzrYhko9fr4eXlhd69e2Pp0qXlPRwiIiJ6TL169cKff/6Jc+fOlfdQiAx4TRzRI8rJySkyF37lypVIS0tD586dy2dQRERE9Mju3r1r9PrcuXPYvHkzc52sDr+JI3pEu3btwqhRo/Cf//wHVapUwZEjR/DFF18gMDAQhw8fhr29fXkPkYiIiCzg4+OD/v37G54Nu3jxYuTm5uLo0aMmnwtHVB54TRzRI6pVqxb8/PwQGxuLtLQ0eHp64o033sCsWbNYwBERESlQt27d8NVXXyE1NRUODg4IDg7GjBkzWMCR1eF0SqJHVKtWLWzcuBGpqanIy8tDamoqli1bhmrVqpX30Egl9uzZgx49esDX1xeSJGHDhg1G64UQmDJlCnx8fODo6IiQkJAi12ykpaUhLCwMrq6ucHd3x8CBA5GVlVWGR0FEpBzLly/HpUuXkJOTg4yMDGzduhUtW7Ys72GRFbGWbGYRR0RkpbKzs9G8eXMsXLiw2PVz5sxBbGwslixZgoMHD8LJyQmhoaFGzyEMCwvDn3/+iW3btuGHH37Anj17MGTIkLI6BCIiIlWxlmzmNXFERAogSRLWr1+PXr16ASg80+fr64v333/f8DyijIwMeHt7Iy4uDv369UNCQgIaNWqE33//Ha1btwYAbN26Fc8//zyuXLkCX1/f8jocIiIixSvPbOY1cRWIXq/H1atX4eLiUuLDqInUSgiB27dvw9fXFzY28k5EyMnJQV5e3kPf/8GfPQcHBzg4OFj8fomJiUhNTUVISIihzc3NDW3btkV8fDz69euH+Ph4uLu7G0ICAEJCQmBjY4ODBw/ipZdesvh9iUhezGaq6JjNj5bNLOIqkKt
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Определите количество строк и столбцов для subplots\n",
"n_rows = int(len(class_models) / 2)\n",
"n_cols = 2\n",
"\n",
"fig, ax = plt.subplots(n_rows, n_cols, figsize=(12, 10), sharex=False, sharey=False)\n",
"\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"No Diabetes\", \"Diabetes\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"# Настройте расположение subplots\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"100 - количество истинных положительных диагнозов (True Positives), где модель правильно определила объекты как \"No Diabetes\".\n",
"\n",
"54 в некоторых моделях - количество ложных отрицательных диагнозов (False Negatives), где модель неправильно определила объекты, которые на самом деле принадлежат к классу \"No Diabetes\", но были отнесены к классу \"Diabetes\". \n",
"\n",
"Исходя из значений True Positives и False Negatives, можно сказать, что модель имеет высокую точность при предсказании класса \"No Diabetes\". В принципе, уровень ложных отрицательных результатов в некоторых моделях (54) говорит нам о том, что существует некотрое небольшое количество примеров, которые модель пропускает.\n",
"\n",
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_90d0d_row0_col0, #T_90d0d_row0_col1, #T_90d0d_row0_col2, #T_90d0d_row0_col3, #T_90d0d_row1_col0, #T_90d0d_row1_col1, #T_90d0d_row1_col2, #T_90d0d_row1_col3, #T_90d0d_row2_col0, #T_90d0d_row2_col1, #T_90d0d_row2_col2, #T_90d0d_row2_col3, #T_90d0d_row3_col0, #T_90d0d_row3_col1, #T_90d0d_row3_col2, #T_90d0d_row3_col3, #T_90d0d_row4_col0, #T_90d0d_row4_col1, #T_90d0d_row4_col2, #T_90d0d_row4_col3, #T_90d0d_row5_col0, #T_90d0d_row5_col1, #T_90d0d_row5_col2, #T_90d0d_row5_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_90d0d_row0_col4, #T_90d0d_row0_col5, #T_90d0d_row0_col6, #T_90d0d_row0_col7, #T_90d0d_row1_col4, #T_90d0d_row1_col5, #T_90d0d_row1_col6, #T_90d0d_row1_col7, #T_90d0d_row2_col4, #T_90d0d_row2_col5, #T_90d0d_row2_col6, #T_90d0d_row2_col7, #T_90d0d_row3_col4, #T_90d0d_row3_col5, #T_90d0d_row3_col6, #T_90d0d_row3_col7, #T_90d0d_row4_col4, #T_90d0d_row4_col5, #T_90d0d_row4_col6, #T_90d0d_row4_col7, #T_90d0d_row5_col4, #T_90d0d_row5_col5, #T_90d0d_row5_col6, #T_90d0d_row5_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_90d0d_row6_col0 {\n",
" background-color: #89d548;\n",
" color: #000000;\n",
"}\n",
"#T_90d0d_row6_col1 {\n",
" background-color: #5ac864;\n",
" color: #000000;\n",
"}\n",
"#T_90d0d_row6_col2 {\n",
" background-color: #77d153;\n",
" color: #000000;\n",
"}\n",
"#T_90d0d_row6_col3 {\n",
" background-color: #65cb5e;\n",
" color: #000000;\n",
"}\n",
"#T_90d0d_row6_col4 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_90d0d_row6_col5 {\n",
" background-color: #b22b8f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_90d0d_row6_col6 {\n",
" background-color: #cc4977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_90d0d_row6_col7 {\n",
" background-color: #c03a83;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_90d0d_row7_col0, #T_90d0d_row7_col1, #T_90d0d_row7_col2, #T_90d0d_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_90d0d_row7_col4, #T_90d0d_row7_col5, #T_90d0d_row7_col6, #T_90d0d_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_90d0d\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_90d0d_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_90d0d_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_90d0d_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_90d0d_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_90d0d_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_90d0d_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_90d0d_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_90d0d_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_90d0d_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_90d0d_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_90d0d_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_90d0d_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_90d0d_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_90d0d_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_90d0d_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_90d0d_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_90d0d_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_90d0d_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_90d0d_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_90d0d_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_90d0d_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_90d0d_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_90d0d_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_90d0d_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_90d0d_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_90d0d_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_90d0d_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_90d0d_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_90d0d_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_90d0d_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_90d0d_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_90d0d_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_90d0d_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_90d0d_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_90d0d_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_90d0d_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_90d0d_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_90d0d_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_90d0d_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_90d0d_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_90d0d_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" <td id=\"T_90d0d_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
" <td id=\"T_90d0d_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
" <td id=\"T_90d0d_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_90d0d_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_90d0d_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_90d0d_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_90d0d_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_90d0d_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_90d0d_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" <td id=\"T_90d0d_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
" <td id=\"T_90d0d_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
" <td id=\"T_90d0d_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_90d0d_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_90d0d_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_90d0d_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_90d0d_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_90d0d_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_90d0d_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" <td id=\"T_90d0d_row5_col5\" class=\"data row5 col5\" >1.000000</td>\n",
" <td id=\"T_90d0d_row5_col6\" class=\"data row5 col6\" >1.000000</td>\n",
" <td id=\"T_90d0d_row5_col7\" class=\"data row5 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_90d0d_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_90d0d_row6_col0\" class=\"data row6 col0\" >0.918367</td>\n",
" <td id=\"T_90d0d_row6_col1\" class=\"data row6 col1\" >0.777778</td>\n",
" <td id=\"T_90d0d_row6_col2\" class=\"data row6 col2\" >0.841121</td>\n",
" <td id=\"T_90d0d_row6_col3\" class=\"data row6 col3\" >0.777778</td>\n",
" <td id=\"T_90d0d_row6_col4\" class=\"data row6 col4\" >0.918567</td>\n",
" <td id=\"T_90d0d_row6_col5\" class=\"data row6 col5\" >0.844156</td>\n",
" <td id=\"T_90d0d_row6_col6\" class=\"data row6 col6\" >0.878049</td>\n",
" <td id=\"T_90d0d_row6_col7\" class=\"data row6 col7\" >0.777778</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_90d0d_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_90d0d_row7_col0\" class=\"data row7 col0\" >0.254237</td>\n",
" <td id=\"T_90d0d_row7_col1\" class=\"data row7 col1\" >0.238095</td>\n",
" <td id=\"T_90d0d_row7_col2\" class=\"data row7 col2\" >0.070093</td>\n",
" <td id=\"T_90d0d_row7_col3\" class=\"data row7 col3\" >0.092593</td>\n",
" <td id=\"T_90d0d_row7_col4\" class=\"data row7 col4\" >0.604235</td>\n",
" <td id=\"T_90d0d_row7_col5\" class=\"data row7 col5\" >0.577922</td>\n",
" <td id=\"T_90d0d_row7_col6\" class=\"data row7 col6\" >0.109890</td>\n",
" <td id=\"T_90d0d_row7_col7\" class=\"data row7 col7\" >0.133333</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2664f38b440>"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Все модели в данной выборке, а именно логистическая регрессия, ридж-регрессия, дерево решений, KNN, наивный байесовский классификатор, градиентный бустинг, случайный лес и многослойный перцептрон (MLP) демонстрируют неплохие значения по всем метрикам на обучающих и тестовых наборах данных.\n",
"\n",
"Модели MLP не так эффективна по сравнению с другими, но в некоторых метриках показывают высокие результаты. "
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_296b9_row0_col0, #T_296b9_row0_col1, #T_296b9_row1_col0, #T_296b9_row1_col1, #T_296b9_row2_col0, #T_296b9_row2_col1, #T_296b9_row3_col0, #T_296b9_row3_col1, #T_296b9_row4_col0, #T_296b9_row4_col1, #T_296b9_row5_col0, #T_296b9_row5_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_296b9_row0_col2, #T_296b9_row0_col3, #T_296b9_row0_col4, #T_296b9_row1_col2, #T_296b9_row1_col3, #T_296b9_row1_col4, #T_296b9_row2_col2, #T_296b9_row2_col3, #T_296b9_row2_col4, #T_296b9_row3_col2, #T_296b9_row3_col3, #T_296b9_row3_col4, #T_296b9_row4_col2, #T_296b9_row4_col3, #T_296b9_row4_col4, #T_296b9_row5_col2, #T_296b9_row5_col3, #T_296b9_row5_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_296b9_row6_col0 {\n",
" background-color: #48c16e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_296b9_row6_col1 {\n",
" background-color: #63cb5f;\n",
" color: #000000;\n",
"}\n",
"#T_296b9_row6_col2 {\n",
" background-color: #c8437b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_296b9_row6_col3, #T_296b9_row6_col4 {\n",
" background-color: #b83289;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_296b9_row7_col0, #T_296b9_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_296b9_row7_col2, #T_296b9_row7_col3, #T_296b9_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_296b9\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_296b9_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_296b9_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_296b9_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_296b9_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_296b9_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_296b9_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_296b9_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_296b9_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_296b9_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_296b9_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_296b9_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_296b9_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_296b9_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_296b9_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_296b9_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_296b9_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_296b9_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_296b9_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_296b9_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_296b9_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_296b9_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_296b9_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_296b9_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_296b9_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_296b9_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_296b9_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_296b9_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_296b9_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_296b9_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_296b9_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_296b9_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_296b9_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_296b9_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_296b9_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_296b9_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_296b9_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_296b9_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_296b9_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_296b9_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_296b9_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_296b9_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_296b9_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_296b9_row6_col0\" class=\"data row6 col0\" >0.844156</td>\n",
" <td id=\"T_296b9_row6_col1\" class=\"data row6 col1\" >0.777778</td>\n",
" <td id=\"T_296b9_row6_col2\" class=\"data row6 col2\" >0.908056</td>\n",
" <td id=\"T_296b9_row6_col3\" class=\"data row6 col3\" >0.657778</td>\n",
" <td id=\"T_296b9_row6_col4\" class=\"data row6 col4\" >0.657778</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_296b9_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_296b9_row7_col0\" class=\"data row7 col0\" >0.577922</td>\n",
" <td id=\"T_296b9_row7_col1\" class=\"data row7 col1\" >0.133333</td>\n",
" <td id=\"T_296b9_row7_col2\" class=\"data row7 col2\" >0.488148</td>\n",
" <td id=\"T_296b9_row7_col3\" class=\"data row7 col3\" >-0.078431</td>\n",
" <td id=\"T_296b9_row7_col4\" class=\"data row7 col4\" >-0.093728</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2664f38a8a0>"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Схожий вывод можно сделать и для следующих метрик: Accuracy, F1, ROC AUC, Cohen's Kappa и MCC. Все модели, кроме KNN и MLP, указывают на хорошо-развитую способность к выделению классов"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Predicted</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [Pregnancies, Predicted, Glucose, BloodPressure, SkinThickness, Insulin, BMI, DiabetesPedigreeFunction, Age, Outcome]\n",
"Index: []"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"Outcome\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>163</th>\n",
" <td>2.0</td>\n",
" <td>100.0</td>\n",
" <td>64.0</td>\n",
" <td>23.0</td>\n",
" <td>0.0</td>\n",
" <td>29.7</td>\n",
" <td>0.368</td>\n",
" <td>21.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"163 2.0 100.0 64.0 23.0 0.0 29.7 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"163 0.368 21.0 0.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome_1</th>\n",
" <th>BMI_to_Age_ratio</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>163</th>\n",
" <td>-0.549372</td>\n",
" <td>-0.663027</td>\n",
" <td>-0.296015</td>\n",
" <td>0.140318</td>\n",
" <td>-0.730766</td>\n",
" <td>-0.289408</td>\n",
" <td>-0.33157</td>\n",
" <td>-1.045895</td>\n",
" <td>0.0</td>\n",
" <td>0.276709</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"163 -0.549372 -0.663027 -0.296015 0.140318 -0.730766 -0.289408 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome_1 BMI_to_Age_ratio \n",
"163 -0.33157 -1.045895 0.0 0.276709 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 0 (proba: [0.98965692 0.01034308])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 0'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 163\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Подбор гиперпараметров методом поиска по сетке "
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\TIGR228\\Desktop\\МИИ\\Lab1\\AIM-PIbd-31-Afanasev-S-S\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 5,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 10}"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"import pandas as pd\n",
"\n",
"\n",
"# Определяем числовые признаки\n",
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
"\n",
"# Установка random_state\n",
"random_state = 42\n",
"\n",
"# Определение трансформера\n",
"pipeline_end = ColumnTransformer([\n",
" ('numeric', StandardScaler(), numeric_features),\n",
" # Добавьте другие трансформеры, если требуется\n",
"])\n",
"\n",
"# Объявление модели\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=50,\n",
")\n",
"\n",
"# Создание пайплайна с корректными шагами\n",
"result = {}\n",
"\n",
"# Обучение модели\n",
"result[\"pipeline\"] = Pipeline([\n",
" (\"pipeline\", pipeline_end),\n",
" (\"model\", optimized_model)\n",
"]).fit(X_train, y_train.values.ravel())\n",
"\n",
"# Прогнозирование и расчет метрик\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"# Метрики для оценки модели\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_a9e08_row0_col0, #T_a9e08_row0_col1, #T_a9e08_row0_col2, #T_a9e08_row0_col3, #T_a9e08_row1_col0, #T_a9e08_row1_col1, #T_a9e08_row1_col2, #T_a9e08_row1_col3 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a9e08_row0_col4, #T_a9e08_row0_col5, #T_a9e08_row0_col6, #T_a9e08_row0_col7, #T_a9e08_row1_col4, #T_a9e08_row1_col5, #T_a9e08_row1_col6, #T_a9e08_row1_col7 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_a9e08\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_a9e08_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_a9e08_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_a9e08_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_a9e08_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_a9e08_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_a9e08_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_a9e08_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_a9e08_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_a9e08_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_a9e08_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_a9e08_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_a9e08_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_a9e08_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_a9e08_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_a9e08_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_a9e08_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_a9e08_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a9e08_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_a9e08_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_a9e08_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_a9e08_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_a9e08_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_a9e08_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_a9e08_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_a9e08_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_a9e08_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x26651c516d0>"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_d47a4_row0_col0, #T_d47a4_row0_col1, #T_d47a4_row1_col0, #T_d47a4_row1_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d47a4_row0_col2, #T_d47a4_row0_col3, #T_d47a4_row0_col4, #T_d47a4_row1_col2, #T_d47a4_row1_col3, #T_d47a4_row1_col4 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_d47a4\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_d47a4_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_d47a4_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_d47a4_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_d47a4_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_d47a4_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_d47a4_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_d47a4_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_d47a4_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_d47a4_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_d47a4_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_d47a4_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d47a4_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_d47a4_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_d47a4_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_d47a4_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_d47a4_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_d47a4_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x26651b45c10>"
]
},
"execution_count": 93,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA5IAAAGxCAYAAAAQ1omjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOw0lEQVR4nO3dfVxUZf7/8fcBBLzhRrzhRhE177A0tVxzK+0Gwy0rs77mLruLpba7ZaZmVlugYom2aS5W6mZJ7s/WbFNTK3fLVss0S1OrlbS8CU3RNhWEFgRmfn+wTjsLKsMczgxzXs/H4zyCczfXTMibz7mucx3D6XQ6BQAAAABALQX5ugEAAAAAgIaFQhIAAAAA4BEKSQAAAACARygkAQAAAAAeoZAEAAAAAHiEQhIAAAAA4BEKSQAAAACARygkAQAAAAAeoZAEAAAAAHiEQhIAAAAA4BEKSQBAQHv//fd18803KyEhQYZhaNWqVW7bnU6nMjMzFR8fr8aNGyslJUVfffWV2z4nTpxQWlqaIiMjFR0drVGjRqm4uNjCdwEAQBV/yTUKSQBAQCspKdGll16q5557rsbtTz31lHJycrRgwQJt3bpVTZs2VWpqqkpLS137pKWl6Z///KfeeecdrV27Vu+//77uueceq94CAAAu/pJrhtPpdHr1TgAAaCAMw9DKlSs1dOhQSVVXbRMSEvTggw9q0qRJkqTCwkLFxsYqNzdXI0aMUF5enrp3765PPvlEl19+uSRp3bp1uvHGG3X48GElJCT46u0AAGzOl7kWUi/vCACA/1JaWqozZ86Ydj6n0ynDMNzWhYWFKSwszKPzHDhwQAUFBUpJSXGti4qKUr9+/bRlyxaNGDFCW7ZsUXR0tCtsJSklJUVBQUHaunWrbrvtNu/eDACgwSHXKCQBAPWstLRUHZKaqeB4pWnnbNasWbV7OaZMmaKpU6d6dJ6CggJJUmxsrNv62NhY17aCggK1bt3abXtISIhiYmJc+wAA7INc+88xHrUMAAAPnTlzRgXHK3Vge5IiI7y/Nb/otEMdLvtGhw4dUmRkpGu9p1dtAQCoC3KtCoUkAMASkRFBpgSu63yRkW6BWxdxcXGSpGPHjik+Pt61/tixY+rVq5drn+PHj7sdV1FRoRMnTriOBwDYj91zjVlbAQCWqHQ6TFvM0qFDB8XFxWn9+vWudUVFRdq6dav69+8vSerfv79OnTql7du3u/Z577335HA41K9fP9PaAgBoWOyea/RIAgAs4ZBTDnk/Ubin5yguLtbXX3/t+v7AgQPauXOnYmJi1K5dO40fP15PPPGEOnfurA4dOigjI0MJCQmuGfCSk5M1ePBgjRkzRgsWLFB5ebnGjh2rESNGMGMrANiY3XONQhIAENC2bduma6+91vX9xIkTJUnp6enKzc3V5MmTVVJSonvuuUenTp3SVVddpXXr1ik8PNx1zNKlSzV27Fhdf/31CgoK0u23366cnBzL3wsAAP6SazxHEgBQr4qKihQVFaUje9qaNilBQtfDKiws9PpeEgAAPEWuVaFHEgBgiUqnU5UmXLs04xwAAHjL7rnGZDsAAAAAAI/QIwkAsISvJiUAAKA+2D3XKCQBAJZwyKlKGwcuACCw2D3XGNoKAAAAAPAIPZIAAEvYfQgQACCw2D3X6JEEAAAAAHiEHkkAgCXsPk06ACCw2D3XKCQBAJZw/Gcx4zwAAPia3XONoa0AAAAAAI/QIwkAsESlSdOkm3EOAAC8Zfdco5AEAFii0lm1mHEeAAB8ze65xtBWAAAAAIBH6JEEAFjC7pMSAAACi91zjUISAGAJhwxVyjDlPAAA+Jrdc42hrQAAAAAAj9AjCQCwhMNZtZhxHgAAfM3uuUaPJAAAAADAI/RIAgAsUWnSvSRmnAMAAG/ZPdcoJAEAlrB74AIAAovdc42hrQAAAAAAj9AjCQCwhMNpyOE0YZp0E84BAIC37J5rFJIAAEvYfQgQACCw2D3XGNoKAAAAAPAIPZIAAEtUKkiVJly/rDShLQAAeMvuuUYhCQCwhNOke0mcDfReEgBAYLF7rjG0FQAAAADgEXokAQCWsPukBACAwGL3XKOQBABYotIZpEqnCfeSOE1oDAAAXrJ7rjG0FQAAAADgEXokAQCWcMiQw4Trlw410Eu3AICAYvdco0cSAAAAAOAReiQBAJaw+6QEAIDAYvdco5AEAFjCvEkJGuYQIABAYLF7rjG0FQAAAADgEXokAQCWqJqUwPvhO2acAwAAb9k91ygkAQCWcChIlTae3Q4AEFjsnmsMbQUAAAAAeIQeSQCAJew+KQEAILDYPdcoJAEAlnAoyNYPbgYABBa75xpDWwEAAAAAHqFHEgBgiUqnoUqnCQ9uNuEcAAB4y+65Ro8kAAAAAMAj9EgCACxRadI06ZUN9F4SAEBgsXuuUUgCACzhcAbJYcLsdo4GOrsdACCw2D3XGNoKAAAAAPAIPZIAAEvYfQgQACCw2D3XKCQBAJZwyJyZ6RzeNwUAAK/ZPdcY2goAAAAA8Ag9kgAASzgUJIcJ1y/NOAcAAN6ye65RSAIALFHpDFKlCbPbmXEOAAC8Zfdca5itBgAAAAD4DD2SAABLOGTIITMmJfD+HAAAeMvuuUYhCQCwhN2HAAEAAovdc61hthoAAAAA4DP0SAIALGHeg5u5BgoA8D2751rDbDUAAAAAwGfokbQRh8OhI0eOKCIiQobRMG/qBWAtp9Op06dPKyEhQUFB3l17dDgNOZwmTEpgwjkQGMg1AJ4i18xDIWkjR44cUWJioq+bAaABOnTokNq2bevVORwmDQFqqA9uhvnINQB1Ra55j0LSRiIiIiRJ33zaXpHNGuYPLOrPbV16+LoJ8EMVKtcmveX6/QH4E3IN50OuoSbkmnkoJG3k7LCfyGZBiowgcOEuxGjk6ybAHzmr/mPGsEGHM0gOE6Y4N+McCAzkGs6HXEONyDXTUEgCACxRKUOVJjx02YxzAADgLbvnWsMsfwEAAAAAPkOPJADAEnYfAgQACCx2zzUKSQCAJSplzvCdSu+bAgCA1+yeaw2z/AUAAAAA+Aw9kgAAS9h9CBAAILDYPdcaZqsBAAAAAD5DjyQAwBKVziBVmnDV1YxzAADgLbvnWsNsNQCgwXHKkMOExenBxAaVlZXKyMhQhw4d1LhxY1100UWaPn26nE7nj+1yOpWZman4+Hg1btxYKSkp+uqrr+rjIwAABBC75xqFJAAgYM2aNUvz58/Xs88+q7y8PM2aNUtPPfWU5s2b59rnqaeeUk5OjhYsWKCtW7eqadOmSk1NVWlpqQ9bDgBAdf6UawxtBQBYwhdDgDZv3qxbb71VN910kySpffv2+stf/qKPP/5YUtVV27lz5+rxxx/XrbfeKklasmSJYmNjtWrVKo0YMcLr9gIAApPdc40eSQCAJRxOw7RFkoqKityWsrKyaq/505/+VOvXr9fevXslSbt27dKmTZv0s5/9TJJ04MABFRQUKCUlxXVMVFSU+vXrpy1btljwqQAAGiq75xo9kgCABikxMdHt+ylTpmjq1Klu6x555BEVFRWpW7duCg4OVmVlpZ588kmlpaVJkgoKCiRJsbGxbsfFxsa6tgEAYIWGlmsUkgAAS1QqSJUmDIQ5e45Dhw4pMjLStT4sLKzavsuXL9fSpUv1yiuv6OKLL9bOnTs1fvx4JSQkKD093eu2AADsy+65RiEJALDEfw/f8fY8khQZGekWuDV56KGH9Mgjj7juCenRo4e++eYbZWdnKz09XXFxcZKkY8eOKT4+3nXcsWPH1KtXL6/bCgAIXHbPNe6RBAAErB9++EFBQe5RFxwcLIfDIUnq0KGD4uLitH79etf2oqIibd26Vf3797e0rQAAXIg/5Ro9kgAASzgUJIcJ1y89OcfNN9+sJ598Uu3atdPFF1+sHTt2aM6cObr77rslSYZhaPz48XriiSfUuXNndejQQRkZGUpISNDQoUO9bisAIHDZPdcoJAEAlqh0Gqo0YQiQJ+eYN2+eMjIydO+99+r48eNKSEjQb37zG2VmZrr2mTx5skpKSnTPPffo1KlTuuq
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"No Diabetes\", \"Diabetes\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В желтых квадрате мы наблюдаем значение 100, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"No Diabetes\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
"\n",
"В бирюзовом квадрате значение 0 указывает на количество правильно классифицированных объектов, отнесенных к классу \"Diabetes\". Это является показателем не такой высокой точности модели в определении объектов данного класса."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(700, 8)\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6</td>\n",
" <td>98</td>\n",
" <td>58</td>\n",
" <td>33</td>\n",
" <td>190</td>\n",
" <td>34.0</td>\n",
" <td>0.430</td>\n",
" <td>43</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>112</td>\n",
" <td>75</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.7</td>\n",
" <td>0.148</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>108</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.8</td>\n",
" <td>0.158</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>8</td>\n",
" <td>107</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.6</td>\n",
" <td>0.856</td>\n",
" <td>34</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>7</td>\n",
" <td>136</td>\n",
" <td>90</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29.9</td>\n",
" <td>0.210</td>\n",
" <td>50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>695</th>\n",
" <td>2</td>\n",
" <td>105</td>\n",
" <td>80</td>\n",
" <td>45</td>\n",
" <td>191</td>\n",
" <td>33.7</td>\n",
" <td>0.711</td>\n",
" <td>29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>696</th>\n",
" <td>1</td>\n",
" <td>126</td>\n",
" <td>56</td>\n",
" <td>29</td>\n",
" <td>152</td>\n",
" <td>28.7</td>\n",
" <td>0.801</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>697</th>\n",
" <td>2</td>\n",
" <td>95</td>\n",
" <td>54</td>\n",
" <td>14</td>\n",
" <td>88</td>\n",
" <td>26.1</td>\n",
" <td>0.748</td>\n",
" <td>22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>698</th>\n",
" <td>3</td>\n",
" <td>100</td>\n",
" <td>68</td>\n",
" <td>23</td>\n",
" <td>81</td>\n",
" <td>31.6</td>\n",
" <td>0.949</td>\n",
" <td>28</td>\n",
" </tr>\n",
" <tr>\n",
" <th>699</th>\n",
" <td>1</td>\n",
" <td>85</td>\n",
" <td>66</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>26.6</td>\n",
" <td>0.351</td>\n",
" <td>31</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>700 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 98 58 33 190 34.0 \n",
"1 2 112 75 32 0 35.7 \n",
"2 2 108 64 0 0 30.8 \n",
"3 8 107 80 0 0 24.6 \n",
"4 7 136 90 0 0 29.9 \n",
".. ... ... ... ... ... ... \n",
"695 2 105 80 45 191 33.7 \n",
"696 1 126 56 29 152 28.7 \n",
"697 2 95 54 14 88 26.1 \n",
"698 3 100 68 23 81 31.6 \n",
"699 1 85 66 29 0 26.6 \n",
"\n",
" DiabetesPedigreeFunction Age \n",
"0 0.430 43 \n",
"1 0.148 21 \n",
"2 0.158 21 \n",
"3 0.856 34 \n",
"4 0.210 50 \n",
".. ... ... \n",
"695 0.711 29 \n",
"696 0.801 21 \n",
"697 0.748 22 \n",
"698 0.949 28 \n",
"699 0.351 31 \n",
"\n",
"[700 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"# Установите random_state для воспроизводимости результатов\n",
"random_state = 42\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"df = pd.read_csv(\"C:/Users/TIGR228/Desktop/МИИ/Lab1/AIM-PIbd-31-Afanasev-S-S/static/csv/diabetes.csv\")\n",
"\n",
"# Удалите столбцы, которые не нужны для анализа\n",
"df = df.drop(columns=[\"Outcome\"])\n",
"\n",
"df = df.sample(n=700, random_state=random_state).reset_index(drop=True)\n",
"\n",
"print(df.shape) \n",
"display(df)"
]
},
{
"cell_type": "code",
2024-11-16 12:52:11 +04:00
"execution_count": 2,
2024-11-16 00:44:40 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"0 6 148 72 35 0 33.6 \n",
"1 1 85 66 29 0 26.6 \n",
"2 8 183 64 0 0 23.3 \n",
"3 1 89 66 23 94 28.1 \n",
"4 0 137 40 35 168 43.1 \n",
"\n",
" DiabetesPedigreeFunction Age diabetes_risk_index \n",
"0 0.627 50 71.68 \n",
"1 0.351 31 46.28 \n",
"2 0.672 32 74.69 \n",
"3 0.167 21 55.33 \n",
"4 2.288 33 81.43 \n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.read_csv(\"C:/Users/TIGR228/Desktop/МИИ/Lab1/AIM-PIbd-31-Afanasev-S-S/static/csv/diabetes.csv\")\n",
"\n",
"required_columns = [\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]\n",
"missing_columns = [col for col in required_columns if col not in df.columns]\n",
"if missing_columns:\n",
" raise ValueError(f\"Отсутствуют столбцы: {missing_columns}\")\n",
"\n",
"df[\"diabetes_risk_index\"] = (\n",
" df[\"Glucose\"] * 0.3 \n",
" + df[\"BMI\"] * 0.3 \n",
" + df[\"Age\"] * 0.2 \n",
" + df[\"BloodPressure\"] * 0.1 \n",
" + df[\"Insulin\"] * 0.1 \n",
")\n",
"\n",
"# Проверка новых данных\n",
"print(df[[\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \"BMI\", \"DiabetesPedigreeFunction\", \"Age\", \"diabetes_risk_index\"]].head())\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии "
]
},
{
"cell_type": "code",
2024-11-16 12:52:11 +04:00
"execution_count": 3,
2024-11-16 00:44:40 +04:00
"metadata": {},
2024-11-16 12:52:11 +04:00
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>60</th>\n",
" <td>2</td>\n",
" <td>84</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.304</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>618</th>\n",
" <td>9</td>\n",
" <td>112</td>\n",
" <td>82</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" <td>28.2</td>\n",
" <td>1.282</td>\n",
" <td>50</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>346</th>\n",
" <td>1</td>\n",
" <td>139</td>\n",
" <td>46</td>\n",
" <td>19</td>\n",
" <td>83</td>\n",
" <td>28.7</td>\n",
" <td>0.654</td>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>0</td>\n",
" <td>161</td>\n",
" <td>50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>21.9</td>\n",
" <td>0.254</td>\n",
" <td>65</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231</th>\n",
" <td>6</td>\n",
" <td>134</td>\n",
" <td>80</td>\n",
" <td>37</td>\n",
" <td>370</td>\n",
" <td>46.2</td>\n",
" <td>0.238</td>\n",
" <td>46</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"60 2 84 0 0 0 0.0 \n",
"618 9 112 82 24 0 28.2 \n",
"346 1 139 46 19 83 28.7 \n",
"294 0 161 50 0 0 21.9 \n",
"231 6 134 80 37 370 46.2 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"60 0.304 21 0 \n",
"618 1.282 50 1 \n",
"346 0.654 22 0 \n",
"294 0.254 65 0 \n",
"231 0.238 46 1 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>diabetes_risk_index</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>60</th>\n",
" <td>29.40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>618</th>\n",
" <td>60.26</td>\n",
" </tr>\n",
" <tr>\n",
" <th>346</th>\n",
" <td>67.61</td>\n",
" </tr>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>72.87</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231</th>\n",
" <td>108.26</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" diabetes_risk_index\n",
"60 29.40\n",
"618 60.26\n",
"346 67.61\n",
"294 72.87\n",
"231 108.26"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pregnancies</th>\n",
" <th>Glucose</th>\n",
" <th>BloodPressure</th>\n",
" <th>SkinThickness</th>\n",
" <th>Insulin</th>\n",
" <th>BMI</th>\n",
" <th>DiabetesPedigreeFunction</th>\n",
" <th>Age</th>\n",
" <th>Outcome</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>668</th>\n",
" <td>6</td>\n",
" <td>98</td>\n",
" <td>58</td>\n",
" <td>33</td>\n",
" <td>190</td>\n",
" <td>34.0</td>\n",
" <td>0.430</td>\n",
" <td>43</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>324</th>\n",
" <td>2</td>\n",
" <td>112</td>\n",
" <td>75</td>\n",
" <td>32</td>\n",
" <td>0</td>\n",
" <td>35.7</td>\n",
" <td>0.148</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>624</th>\n",
" <td>2</td>\n",
" <td>108</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30.8</td>\n",
" <td>0.158</td>\n",
" <td>21</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>690</th>\n",
" <td>8</td>\n",
" <td>107</td>\n",
" <td>80</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>24.6</td>\n",
" <td>0.856</td>\n",
" <td>34</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>473</th>\n",
" <td>7</td>\n",
" <td>136</td>\n",
" <td>90</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>29.9</td>\n",
" <td>0.210</td>\n",
" <td>50</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pregnancies Glucose BloodPressure SkinThickness Insulin BMI \\\n",
"668 6 98 58 33 190 34.0 \n",
"324 2 112 75 32 0 35.7 \n",
"624 2 108 64 0 0 30.8 \n",
"690 8 107 80 0 0 24.6 \n",
"473 7 136 90 0 0 29.9 \n",
"\n",
" DiabetesPedigreeFunction Age Outcome \n",
"668 0.430 43 0 \n",
"324 0.148 21 0 \n",
"624 0.158 21 0 \n",
"690 0.856 34 0 \n",
"473 0.210 50 0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>diabetes_risk_index</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>668</th>\n",
" <td>73.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>324</th>\n",
" <td>56.01</td>\n",
" </tr>\n",
" <tr>\n",
" <th>624</th>\n",
" <td>52.24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>690</th>\n",
" <td>54.28</td>\n",
" </tr>\n",
" <tr>\n",
" <th>473</th>\n",
" <td>68.77</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" diabetes_risk_index\n",
"668 73.00\n",
"324 56.01\n",
"624 52.24\n",
"690 54.28\n",
"473 68.77"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" target_colname: str = \"diabetes_risk_index\",\n",
" frac_train: float = 0.8,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" # Проверка наличия целевого признака\n",
" if target_colname not in df_input.columns:\n",
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
" \n",
" # Разделяем данные на признаки и целевую переменную\n",
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
" y = df_input[[target_colname]] # Целевая переменная\n",
"\n",
" # Разделяем данные на обучающую и тестовую выборки\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" \n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"# Применение функции для разделения данных\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
" df, \n",
" target_colname=\"diabetes_risk_index\", \n",
" frac_train=0.8, \n",
" random_state=42\n",
")\n",
"\n",
"# Для отображения результатов\n",
"display(\"X_train\", X_train.head())\n",
"display(\"y_train\", y_train.head())\n",
"\n",
"display(\"X_test\", X_test.head())\n",
"display(\"y_test\", y_test.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"# Установка random_state для воспроизводимости\n",
"random_state = 42\n",
"\n",
"# Словарь моделей\n",
"models = {\n",
" # Линейная регрессия\n",
" \"linear\": {\n",
" \"model\": linear_model.LinearRegression(n_jobs=-1)\n",
" },\n",
" # Полиномиальная регрессия степени 2\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" StandardScaler(), # Добавляем масштабирование данных\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" # Полиномиальная регрессия с взаимодействиями\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" StandardScaler(), # Масштабирование данных\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" # Ridge-регрессия\n",
" \"ridge\": {\n",
" \"model\": make_pipeline(\n",
" StandardScaler(), # Масштабирование данных\n",
" linear_model.RidgeCV()\n",
" )\n",
" },\n",
" # Регрессия на основе дерева решений\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" # Метод ближайших соседей (kNN)\n",
" \"knn\": {\n",
" \"model\": make_pipeline(\n",
" StandardScaler(), # Масштабирование данных\n",
" neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)\n",
" )\n",
" },\n",
" # Случайный лес (Random Forest)\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" # Нейронная сеть (MLPRegressor)\n",
" \"mlp\": {\n",
" \"model\": make_pipeline(\n",
" StandardScaler(), # Масштабирование данных\n",
" neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3,), # Простая архитектура сети\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" )\n",
" },\n",
"}\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для регрессии"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\TIGR228\\Desktop\\МИИ\\Lab1\\AIM-PIbd-31-Afanasev-S-S\\aimenv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:690: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n",
" warnings.warn(\n"
]
}
],
"source": [
"import math\n",
"from pandas import DataFrame\n",
"from sklearn import metrics\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
"\n",
" fitted_model = models[model_name][\"model\"].fit(\n",
" X_train.values, y_train.values.ravel()\n",
" )\n",
" y_train_pred = fitted_model.predict(X_train.values)\n",
" y_test_pred = fitted_model.predict(X_test.values)\n",
" models[model_name][\"fitted\"] = fitted_model\n",
" models[model_name][\"train_preds\"] = y_train_pred\n",
" models[model_name][\"preds\"] = y_test_pred\n",
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_train, y_train_pred)\n",
" )\n",
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод результатов оценки"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_da57c_row0_col0, #T_da57c_row0_col1, #T_da57c_row1_col0, #T_da57c_row1_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row0_col2, #T_da57c_row1_col2, #T_da57c_row7_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row0_col3, #T_da57c_row1_col3, #T_da57c_row2_col3, #T_da57c_row3_col3, #T_da57c_row7_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row2_col0 {\n",
" background-color: #25838e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row2_col1 {\n",
" background-color: #25858e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row2_col2 {\n",
" background-color: #7100a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row3_col0 {\n",
" background-color: #25848e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row3_col1 {\n",
" background-color: #24878e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row3_col2 {\n",
" background-color: #7501a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row4_col0 {\n",
" background-color: #23888e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row4_col1 {\n",
" background-color: #23898e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row4_col2 {\n",
" background-color: #7a02a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row4_col3, #T_da57c_row6_col2 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row5_col0, #T_da57c_row5_col1 {\n",
" background-color: #89d548;\n",
" color: #000000;\n",
"}\n",
"#T_da57c_row5_col2 {\n",
" background-color: #d24f71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row5_col3 {\n",
" background-color: #7201a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_da57c_row6_col0, #T_da57c_row7_col0, #T_da57c_row7_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_da57c_row6_col1 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_da57c_row6_col3 {\n",
" background-color: #5302a3;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_da57c\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_da57c_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_da57c_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_da57c_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_da57c_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_da57c_level0_row0\" class=\"row_heading level0 row0\" >linear</th>\n",
" <td id=\"T_da57c_row0_col0\" class=\"data row0 col0\" >0.000000</td>\n",
" <td id=\"T_da57c_row0_col1\" class=\"data row0 col1\" >0.000000</td>\n",
" <td id=\"T_da57c_row0_col2\" class=\"data row0 col2\" >0.000000</td>\n",
" <td id=\"T_da57c_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_da57c_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_da57c_row1_col0\" class=\"data row1 col0\" >0.002409</td>\n",
" <td id=\"T_da57c_row1_col1\" class=\"data row1 col1\" >0.002181</td>\n",
" <td id=\"T_da57c_row1_col2\" class=\"data row1 col2\" >0.040847</td>\n",
" <td id=\"T_da57c_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_da57c_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_da57c_row2_col0\" class=\"data row2 col0\" >1.856691</td>\n",
" <td id=\"T_da57c_row2_col1\" class=\"data row2 col1\" >3.387390</td>\n",
" <td id=\"T_da57c_row2_col2\" class=\"data row2 col2\" >1.612410</td>\n",
" <td id=\"T_da57c_row2_col3\" class=\"data row2 col3\" >0.967338</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_da57c_level0_row3\" class=\"row_heading level0 row3\" >decision_tree</th>\n",
" <td id=\"T_da57c_row3_col0\" class=\"data row3 col0\" >2.409987</td>\n",
" <td id=\"T_da57c_row3_col1\" class=\"data row3 col1\" >4.281378</td>\n",
" <td id=\"T_da57c_row3_col2\" class=\"data row3 col2\" >1.847379</td>\n",
" <td id=\"T_da57c_row3_col3\" class=\"data row3 col3\" >0.947823</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_da57c_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_da57c_row4_col0\" class=\"data row4 col0\" >5.170083</td>\n",
" <td id=\"T_da57c_row4_col1\" class=\"data row4 col1\" >5.824576</td>\n",
" <td id=\"T_da57c_row4_col2\" class=\"data row4 col2\" >2.073599</td>\n",
" <td id=\"T_da57c_row4_col3\" class=\"data row4 col3\" >0.903431</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_da57c_level0_row5\" class=\"row_heading level0 row5\" >mlp</th>\n",
" <td id=\"T_da57c_row5_col0\" class=\"data row5 col0\" >60.594447</td>\n",
" <td id=\"T_da57c_row5_col1\" class=\"data row5 col1\" >59.994001</td>\n",
" <td id=\"T_da57c_row5_col2\" class=\"data row5 col2\" >7.553087</td>\n",
" <td id=\"T_da57c_row5_col3\" class=\"data row5 col3\" >-9.245313</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_da57c_level0_row6\" class=\"row_heading level0 row6\" >linear_poly</th>\n",
" <td id=\"T_da57c_row6_col0\" class=\"data row6 col0\" >67.720820</td>\n",
" <td id=\"T_da57c_row6_col1\" class=\"data row6 col1\" >66.518485</td>\n",
" <td id=\"T_da57c_row6_col2\" class=\"data row6 col2\" >8.144355</td>\n",
" <td id=\"T_da57c_row6_col3\" class=\"data row6 col3\" >-11.594887</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_da57c_level0_row7\" class=\"row_heading level0 row7\" >linear_interact</th>\n",
" <td id=\"T_da57c_row7_col0\" class=\"data row7 col0\" >67.518306</td>\n",
" <td id=\"T_da57c_row7_col1\" class=\"data row7 col1\" >67.518306</td>\n",
" <td id=\"T_da57c_row7_col2\" class=\"data row7 col2\" >8.216952</td>\n",
" <td id=\"T_da57c_row7_col3\" class=\"data row7 col3\" >-11.976353</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1d7b697bd10>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок\n",
"\n",
"Получение лучшей модели"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'linear'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 27 candidates, totalling 81 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\TIGR228\\Desktop\\МИИ\\Lab1\\AIM-PIbd-31-Afanasev-S-S\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие параметры: {'max_depth': 20, 'min_samples_split': 5, 'n_estimators': 150}\n",
"Лучший результат (MSE): 11.258082426680579\n",
"Ошибка на тестовой выборке (MSE): 9.1015\n",
"Коэффициент детерминации (R²): 0.9741\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Удаление пропущенных значений (если есть)\n",
"df.dropna(inplace=True)\n",
"\n",
"# Предикторы и целевая переменная\n",
"X = df[[\"Pregnancies\", \"Glucose\", \"BloodPressure\", \"SkinThickness\", \"Insulin\", \n",
" \"BMI\", \"DiabetesPedigreeFunction\", \"Age\"]] \n",
"y = df[\"diabetes_risk_index\"] \n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y, test_size=0.2, random_state=42\n",
")\n",
"\n",
"# Определение модели случайного леса\n",
"model = RandomForestRegressor(random_state=42)\n",
"\n",
"# Сетка гиперпараметров для поиска\n",
"param_grid = {\n",
" 'n_estimators': [50, 100, 150], \n",
" 'max_depth': [10, 20, 30], \n",
" 'min_samples_split': [5, 10, 15] \n",
"}\n",
"\n",
"# Настройка поиска по сетке\n",
"grid_search = GridSearchCV(\n",
" estimator=model,\n",
" param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', \n",
" cv=3, \n",
" n_jobs=-1, \n",
" verbose=2 \n",
")\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_train, y_train)\n",
"\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)\n",
"\n",
"# Оценка на тестовой выборке\n",
"best_model = grid_search.best_estimator_\n",
"y_pred = best_model.predict(X_test)\n",
"\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"r2 = metrics.r2_score(y_test, y_pred)\n",
"\n",
"print(f\"Ошибка на тестовой выборке (MSE): {mse:.4f}\")\n",
"print(f\"Коэффициент детерминации (R²): {r2:.4f}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами и сравнение новых и старых данных"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n",
"Старые параметры: {'max_depth': 20, 'min_samples_split': 5, 'n_estimators': 50}\n",
"Лучший результат (MSE) на старых параметрах: 11.60966872499276\n",
"\n",
"Новые параметры: {'max_depth': 20, 'min_samples_split': 10, 'n_estimators': 100}\n",
"Лучший результат (MSE) на новых параметрах: 19.147439110906816\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 10.814392376125927\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 3.288524346287545\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"old_param_grid = {\n",
" 'n_estimators': [50, 100], # Количество деревьев\n",
" 'max_depth': [ 10, 20], # Максимальная глубина дерева\n",
" 'min_samples_split': [5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=old_param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"old_grid_search.fit(X_train, y_train)\n",
"\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
"\n",
"new_param_grid = {\n",
" 'n_estimators': [100],\n",
" 'max_depth': [20],\n",
" 'min_samples_split': [10]\n",
"}\n",
"\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"new_grid_search.fit(X_train, y_train)\n",
"\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
"\n",
"model_best = RandomForestRegressor(**new_best_params)\n",
"model_best.fit(X_train, y_train)\n",
"\n",
"model_oldbest = RandomForestRegressor(**old_best_params)\n",
"model_oldbest.fit(X_train, y_train)\n",
"\n",
"y_pred = model_best.predict(X_test)\n",
"y_oldpred = model_oldbest.predict(X_test)\n",
"\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nНовые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Попробуем визуализировать\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1wAAAHWCAYAAABjUYhTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXhTZdr48e9J0yZtmqSU7rShBUqbIiIFFRR3RxG3VkdG5VV0dNCZ0b6OG+MoKuM6M+pgdVxwXEcd9/Z1mXF+7gsjKFAUbSoIlLSl0KZLmqYkTZrz+yM0ElqwlJZ0uT/X1QuSc3rynJM0yf0893M/iqqqKkIIIYQQQgghBpwm0g0QQgghhBBCiJFKAi4hhBBCCCGEGCQScAkhhBBCCCHEIJGASwghhBBCCCEGiQRcQgghhBBCCDFIJOASQgghhBBCiEEiAZcQQgghhBBCDBIJuIQQQgghhBBikEjAJYQQQgghhBCDRAIuIYQQgyI7O5tLLrkk0s0QQgghIkoCLiHEqPXMM8+gKAqrV6/udfvxxx/PIYcccpBbNTp5PB7++te/cuSRR2I2m9Hr9UyePJmrrrqKDRs2RLp5EXH77bejKMo+f84444xIN1MIIcRP0Ea6AUIIIUY3h8PB3LlzWbNmDWeccQYXXngh8fHxfP/997z00kssX76czs7OSDczYh599FHi4+N73P+73/0uAq0RQgixvyTgEkIIEVGXXHIJFRUVvPbaa5x77rlh2+644w5uvvnmCLVsaPj5z39OUlJSj/tvueWWCLRGCCHE/pKUQiGE2A9+v5877riDiRMnotPpyM7O5g9/+ANerzdsv+zs7FDal0ajIS0tjV/84hfY7fbQPtXV1SiKwn333bfXx+tOK9vT888/z4wZM4iNjSUxMZHzzz+fmpqan2z/1q1b+c1vfkNeXh6xsbGMHTuW8847j+rq6rD9utMtV6xYwbXXXktycjIGg4Hi4mIaGxvD9lVVlTvvvJPMzEzi4uI44YQT+O67736yLQCrVq3inXfe4bLLLusRbAHodLrQ9bnkkkt+MsWu+zz+7//+j9NPP52MjAx0Oh0TJ07kjjvuoKurK+z43Wmja9as4aijjiI2NpacnBwee+yxsP0+/vhjFEXhtdde2+u5XHLJJWRnZ4fdFwgEWLZsGVOmTEGv15OamsoVV1xBS0tLn67P/rrvvvs46qijGDt2LLGxscyYMaPXNiuKwlVXXcULL7xAXl4eer2eGTNm8Omnn4btt7+vl5iYmB6vjy+++CL0/OyZvrtq1Srmzp2L2WwmLi6O4447jhUrVoS29yWt8uOPPwYG/7kUQoj+khEuIcSo53Q6cTgcPe73+Xw97rv88st59tln+fnPf851113HqlWruOeee7DZbJSVlYXte8wxx7Bo0SICgQDffvsty5YtY9u2bXz22WcH1N677rqLJUuWMH/+fC6//HIaGxt56KGHOPbYY6moqCAhIWGvv/vVV1/x3//+l/PPP5/MzEyqq6t59NFHOf7446msrCQuLi5s/6uvvpoxY8Zw2223UV1dzbJly7jqqqt4+eWXQ/vceuut3HnnncybN4958+axdu1aTjnllD6lAb755psAXHTRRT+57xVXXMHJJ58cun3RRRdRXFzMOeecE7ovOTkZCAYA8fHxXHvttcTHx/Phhx9y66230tbWxl/+8pew47a0tDBv3jzmz5/PBRdcwCuvvMKvf/1rYmJi+OUvf/mT7fqpNj/zzDNceumllJSUsGXLFh5++GEqKipYsWIF0dHRB3T8PT344IOcddZZLFiwgM7OTl566SXOO+883n77bU4//fSwfT/55BNefvllSkpK0Ol0PPLII8ydO5cvv/wyNHdxf18vUVFRPP/882Hpjk8//TR6vR6PxxO274cffshpp53GjBkzuO2229BoNDz99NOceOKJfPbZZxxxxBGcc845TJo0KfQ7v/vd77BarSxatCh0n9VqDf1/MJ9LIYToN1UIIUapp59+WgX2+TNlypTQ/uvWrVMB9fLLLw87zvXXX68C6ocffhi6b/z48erChQvD9rvwwgvVuLi40O0tW7aogPqXv/xlr2287bbb1N3fqqurq9WoqCj1rrvuCttv/fr1qlar7XH/njo6Onrc98UXX6iA+txzz4Xu6742J598shoIBEL3/+53v1OjoqLU1tZWVVVVtaGhQY2JiVFPP/30sP3+8Ic/qECPa7Cn4uJiFVBbWlr2uV9vAPW2227rdVtv53nFFVeocXFxqsfjCd133HHHqYB6//33h+7zer3qYYcdpqakpKidnZ2qqqrqRx99pALqq6++utf2LFy4UB0/fnzo9meffaYC6gsvvBC237vvvtvr/Xvqfu4bGxt73T5+/Hj19NNPD7tvz/Pu7OxUDznkEPXEE08Mu7/79b169erQfVu3blX1er1aXFy81+Op6r5fLxdccIE6derU0P1ut1s1mUzqhRdeqALqV199paqqqgYCATU3N1c99dRTw143HR0dak5Ojvqzn/1sr+e8t9fUYD6XQghxICSlUAgx6v3tb3/jvffe6/Fz6KGHhu33r3/9C4Brr7027P7rrrsOgHfeeSfsfq/Xi8PhoKGhgffee48PP/yQk046qcfjd3R04HA4aGlpQVXVfbb1jTfeIBAIMH/+fBwOR+gnLS2N3NxcPvroo33+fmxsbOj/Pp+PpqYmJk2aREJCAmvXru2x/6JFi8JSGo855hi6urrYunUrAO+//z6dnZ1cffXVYftdc801+2xHt7a2NgCMRmOf9u+r3c/T5XLhcDg45phj6OjooKqqKmxfrVbLFVdcEbodExPDFVdcQUNDA2vWrAnbt/tYra2tP9mGV199FbPZzM9+9rOw52rGjBnEx8f/5HPVH7ufd0tLC06nk2OOOabX53b27NnMmDEjdNtisXD22Wfzn//8J5R6ub+vl4suuoiqqqpQ6uDrr7+O2Wzu8bpft24dGzdu5MILL6SpqSl0bdxuNyeddBKffvopgUBgv89/sJ5LIYQ4EJJSKIQY9Y444ghmzpzZ4/4xY8aEpRpu3boVjUYTluIEkJaWRkJCQigI6fbSSy/x0ksvhW4ffvjh/P3vf+/xOLfddhu33XYbAHq9nhNPPJFly5aRm5vbY9+NGzeiqmqv24CfTFHbuXMn99xzD08//TR1dXVhAZ7T6eyxv8ViCbs9ZswYgNAcpO5z3rM9ycnJoX33xWQyAcEvv/tKhdxf3333HbfccgsffvhhKKjrtud5ZmRkYDAYwu6bPHkyEJxnN2vWrND9u6elxcfHc+aZZ/LXv/6V1NTUHm3YuHEjTqeTlJSUXtvY0NCwfyfVB2+//TZ33nkn69atC5tX2Ns8wN5eQ5MnT6ajo4PGxkbS0tL2+/WSnJzM6aefzlNPPcXMmTN56qmnWLhwIRpNeP/uxo0bAVi4cOFez8XpdPbpNbS7wXouhRDiQEjAJYQQ+6m3L6+9OeWUU7jhhhsAqK2t5U9/+hMnnHACq1evDhs5WLRoEeeddx5dXV3YbDZuv/12ioqKei08EQgEUBSFf//730RFRfXY3lv58N1dffXVPP3001xzzTXMnj0bs9mMoiicf/75vY4o9PYYwE+OxPVVfn4+AOvXr+eYY44ZkGO2trZy3HHHYTKZ+OMf/8jEiRPR6/WsXbuWxYsX92vkpNutt97KMcccg8/nY82aNfzxj3+ktbU1NPq5u0AgQEpKCi+88EKvx+qebzZQPvvsM8466yyOPfZYHnnkEdLT04mOjubpp5/mxRdf7Ncx9/f1AsFA5uKLL+bqq6/m008/5e9//3uPeYvdv/uXv/yFww47rNfj/NRr+UDtz3MphBAHQgIuIYToo/HjxxMIBNi4cWPYRP0dO3bQ2trK+PHjw/ZPT08PK/KQl5fHUUcdRXl5ORdccEHo/tzc3NB+p556Kh0dHdx8881hFQ27TZw4EVVVycnJCfXc74/XXnuNhQsXcv/994fu83g8/U6r6j7njRs3MmHChND9jY2NfarEd+aZZ3LPPffw/PPPD1jA9fHHH9PU1MQbb7zBscceG7p
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 5))\n",
"plt.scatter(range(len(y_test)), y_test, label=\"Актуальные значения\", color=\"black\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_pred, label=\"новые параметры\", color=\"blue\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_oldpred, label=\"старые параметры\", color=\"red\", alpha=0.5)\n",
"plt.xlabel(\"Выборка\")\n",
"plt.ylabel(\"Значения\")\n",
"plt.legend()\n",
"plt.title(\"Новые and Старые Параметры\")\n",
"plt.show()"
]
2024-11-16 00:44:40 +04:00
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}