MII/mai/lab4.ipynb

2887 lines
258 KiB
Plaintext
Raw Normal View History

2024-12-14 15:49:48 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Лабораторная работа 4\n",
"\n",
"Датасет - **Цены на бриллианты**\thttps://www.kaggle.com/datasets/nancyalaswad90/diamonds-prices\n",
"\n",
"1. **carat**: Вес бриллианта в каратах\n",
"2. **cut**: Качество огранки.\n",
"3. **color**: Цвет бриллианта\n",
"4. **clarity**: Чистота бриллианта\n",
"5. **depth**: Процент глубины бриллианта\n",
"6. **table**: Процент ширины бриллианта\n",
"7. **price**: Цена бриллианта в долларах США\n",
"8. **x**: Длина бриллианта в миллиметрах\n",
"9. **y**: Ширина бриллианта в миллиметрах\n",
"10. **z**: Глубина бриллианта в миллиметрах"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Бизнес-цели**: \n",
"1. Прогнозирование цены бриллиантов на основании характеристик.\n",
"2. Анализ частотности и сочетания характеристик бриллиантов, которые пользуются наибольшим спросом, чтобы лучше планировать запасы. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Загрузка набора данных"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Среднее значение поля 'карат': 0.7979346717831785\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.23</td>\n",
" <td>Ideal</td>\n",
" <td>E</td>\n",
" <td>SI2</td>\n",
" <td>61.5</td>\n",
" <td>55.0</td>\n",
" <td>326</td>\n",
" <td>3.95</td>\n",
" <td>3.98</td>\n",
" <td>2.43</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.21</td>\n",
" <td>Premium</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>59.8</td>\n",
" <td>61.0</td>\n",
" <td>326</td>\n",
" <td>3.89</td>\n",
" <td>3.84</td>\n",
" <td>2.31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.23</td>\n",
" <td>Good</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>56.9</td>\n",
" <td>65.0</td>\n",
" <td>327</td>\n",
" <td>4.05</td>\n",
" <td>4.07</td>\n",
" <td>2.31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.29</td>\n",
" <td>Premium</td>\n",
" <td>I</td>\n",
" <td>VS2</td>\n",
" <td>62.4</td>\n",
" <td>58.0</td>\n",
" <td>334</td>\n",
" <td>4.20</td>\n",
" <td>4.23</td>\n",
" <td>2.63</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.31</td>\n",
" <td>Good</td>\n",
" <td>J</td>\n",
" <td>SI2</td>\n",
" <td>63.3</td>\n",
" <td>58.0</td>\n",
" <td>335</td>\n",
" <td>4.34</td>\n",
" <td>4.35</td>\n",
" <td>2.75</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53939</th>\n",
" <td>0.86</td>\n",
" <td>Premium</td>\n",
" <td>H</td>\n",
" <td>SI2</td>\n",
" <td>61.0</td>\n",
" <td>58.0</td>\n",
" <td>2757</td>\n",
" <td>6.15</td>\n",
" <td>6.12</td>\n",
" <td>3.74</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53940</th>\n",
" <td>0.75</td>\n",
" <td>Ideal</td>\n",
" <td>D</td>\n",
" <td>SI2</td>\n",
" <td>62.2</td>\n",
" <td>55.0</td>\n",
" <td>2757</td>\n",
" <td>5.83</td>\n",
" <td>5.87</td>\n",
" <td>3.64</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53941</th>\n",
" <td>0.71</td>\n",
" <td>Premium</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>60.5</td>\n",
" <td>55.0</td>\n",
" <td>2756</td>\n",
" <td>5.79</td>\n",
" <td>5.74</td>\n",
" <td>3.49</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53942</th>\n",
" <td>0.71</td>\n",
" <td>Premium</td>\n",
" <td>F</td>\n",
" <td>SI1</td>\n",
" <td>59.8</td>\n",
" <td>62.0</td>\n",
" <td>2756</td>\n",
" <td>5.74</td>\n",
" <td>5.73</td>\n",
" <td>3.43</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53943</th>\n",
" <td>0.70</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>60.5</td>\n",
" <td>59.0</td>\n",
" <td>2757</td>\n",
" <td>5.71</td>\n",
" <td>5.76</td>\n",
" <td>3.47</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>53943 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z \\\n",
"id \n",
"1 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 \n",
"2 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 \n",
"3 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 \n",
"4 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 \n",
"5 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 \n",
"... ... ... ... ... ... ... ... ... ... ... \n",
"53939 0.86 Premium H SI2 61.0 58.0 2757 6.15 6.12 3.74 \n",
"53940 0.75 Ideal D SI2 62.2 55.0 2757 5.83 5.87 3.64 \n",
"53941 0.71 Premium E SI1 60.5 55.0 2756 5.79 5.74 3.49 \n",
"53942 0.71 Premium F SI1 59.8 62.0 2756 5.74 5.73 3.43 \n",
"53943 0.70 Very Good E VS2 60.5 59.0 2757 5.71 5.76 3.47 \n",
"\n",
" above_average_carat \n",
"id \n",
"1 0 \n",
"2 0 \n",
"3 0 \n",
"4 0 \n",
"5 0 \n",
"... ... \n",
"53939 1 \n",
"53940 0 \n",
"53941 0 \n",
"53942 0 \n",
"53943 0 \n",
"\n",
"[53943 rows x 11 columns]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"df = pd.read_csv(\"data/Diamonds.csv\", index_col=\"id\")\n",
"\n",
"random_state=42\n",
"#считаем средний вес бриллиантов\n",
"average_carat = df['carat'].mean()\n",
"\n",
"print(f\"Среднее значение поля 'карат': {average_carat}\")\n",
"\n",
"#новый столбец, в котором 1 значит больше ср знач, 0 меньше\n",
"average_carat = df['carat'].mean()\n",
"df['above_average_carat'] = (df['carat'] > average_carat).astype(int)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n",
"\n",
"Целевой признак -- Cut"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>38836</th>\n",
" <td>0.40</td>\n",
" <td>Very Good</td>\n",
" <td>F</td>\n",
" <td>VVS2</td>\n",
" <td>62.0</td>\n",
" <td>56.0</td>\n",
" <td>1049</td>\n",
" <td>4.71</td>\n",
" <td>4.74</td>\n",
" <td>2.93</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30260</th>\n",
" <td>0.40</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>63.0</td>\n",
" <td>57.0</td>\n",
" <td>725</td>\n",
" <td>4.68</td>\n",
" <td>4.71</td>\n",
" <td>2.96</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33169</th>\n",
" <td>0.36</td>\n",
" <td>Ideal</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>61.8</td>\n",
" <td>56.0</td>\n",
" <td>817</td>\n",
" <td>4.55</td>\n",
" <td>4.58</td>\n",
" <td>2.82</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1029</th>\n",
" <td>0.70</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>58.4</td>\n",
" <td>59.0</td>\n",
" <td>2904</td>\n",
" <td>5.83</td>\n",
" <td>5.91</td>\n",
" <td>3.43</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53809</th>\n",
" <td>0.81</td>\n",
" <td>Very Good</td>\n",
" <td>G</td>\n",
" <td>SI1</td>\n",
" <td>60.7</td>\n",
" <td>56.0</td>\n",
" <td>2733</td>\n",
" <td>6.06</td>\n",
" <td>6.09</td>\n",
" <td>3.69</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2937</th>\n",
" <td>0.77</td>\n",
" <td>Good</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>63.4</td>\n",
" <td>57.0</td>\n",
" <td>3291</td>\n",
" <td>5.80</td>\n",
" <td>5.84</td>\n",
" <td>3.69</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7514</th>\n",
" <td>0.90</td>\n",
" <td>Good</td>\n",
" <td>F</td>\n",
" <td>SI1</td>\n",
" <td>61.8</td>\n",
" <td>63.0</td>\n",
" <td>4241</td>\n",
" <td>6.21</td>\n",
" <td>6.18</td>\n",
" <td>3.83</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48344</th>\n",
" <td>0.56</td>\n",
" <td>Ideal</td>\n",
" <td>H</td>\n",
" <td>VVS1</td>\n",
" <td>62.1</td>\n",
" <td>53.8</td>\n",
" <td>1961</td>\n",
" <td>5.27</td>\n",
" <td>5.33</td>\n",
" <td>3.29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3212</th>\n",
" <td>0.70</td>\n",
" <td>Premium</td>\n",
" <td>F</td>\n",
" <td>VVS1</td>\n",
" <td>61.8</td>\n",
" <td>60.0</td>\n",
" <td>3348</td>\n",
" <td>5.67</td>\n",
" <td>5.63</td>\n",
" <td>3.49</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35654</th>\n",
" <td>0.31</td>\n",
" <td>Very Good</td>\n",
" <td>G</td>\n",
" <td>VVS2</td>\n",
" <td>63.1</td>\n",
" <td>57.0</td>\n",
" <td>907</td>\n",
" <td>4.32</td>\n",
" <td>4.30</td>\n",
" <td>2.72</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>43154 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z\n",
"id \n",
"38836 0.40 Very Good F VVS2 62.0 56.0 1049 4.71 4.74 2.93\n",
"30260 0.40 Very Good E SI1 63.0 57.0 725 4.68 4.71 2.96\n",
"33169 0.36 Ideal E VS1 61.8 56.0 817 4.55 4.58 2.82\n",
"1029 0.70 Very Good E VS1 58.4 59.0 2904 5.83 5.91 3.43\n",
"53809 0.81 Very Good G SI1 60.7 56.0 2733 6.06 6.09 3.69\n",
"... ... ... ... ... ... ... ... ... ... ...\n",
"2937 0.77 Good E VS2 63.4 57.0 3291 5.80 5.84 3.69\n",
"7514 0.90 Good F SI1 61.8 63.0 4241 6.21 6.18 3.83\n",
"48344 0.56 Ideal H VVS1 62.1 53.8 1961 5.27 5.33 3.29\n",
"3212 0.70 Premium F VVS1 61.8 60.0 3348 5.67 5.63 3.49\n",
"35654 0.31 Very Good G VVS2 63.1 57.0 907 4.32 4.30 2.72\n",
"\n",
"[43154 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"id\n",
"38836 0\n",
"30260 0\n",
"33169 0\n",
"1029 0\n",
"53809 1\n",
" ..\n",
"2937 0\n",
"7514 1\n",
"48344 0\n",
"3212 0\n",
"35654 0\n",
"Name: above_average_carat, Length: 43154, dtype: int64"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>32452</th>\n",
" <td>0.39</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>60.9</td>\n",
" <td>58.0</td>\n",
" <td>793</td>\n",
" <td>4.72</td>\n",
" <td>4.77</td>\n",
" <td>2.89</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2432</th>\n",
" <td>0.72</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>63.3</td>\n",
" <td>56.0</td>\n",
" <td>3183</td>\n",
" <td>5.67</td>\n",
" <td>5.71</td>\n",
" <td>3.60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16456</th>\n",
" <td>1.21</td>\n",
" <td>Ideal</td>\n",
" <td>H</td>\n",
" <td>SI1</td>\n",
" <td>62.1</td>\n",
" <td>59.0</td>\n",
" <td>6573</td>\n",
" <td>6.81</td>\n",
" <td>6.75</td>\n",
" <td>4.21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46045</th>\n",
" <td>0.56</td>\n",
" <td>Ideal</td>\n",
" <td>D</td>\n",
" <td>SI1</td>\n",
" <td>62.5</td>\n",
" <td>56.0</td>\n",
" <td>1729</td>\n",
" <td>5.28</td>\n",
" <td>5.24</td>\n",
" <td>3.29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11115</th>\n",
" <td>1.00</td>\n",
" <td>Good</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>62.4</td>\n",
" <td>59.0</td>\n",
" <td>4936</td>\n",
" <td>6.35</td>\n",
" <td>6.40</td>\n",
" <td>3.98</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40250</th>\n",
" <td>0.50</td>\n",
" <td>Premium</td>\n",
" <td>F</td>\n",
" <td>SI1</td>\n",
" <td>59.6</td>\n",
" <td>61.0</td>\n",
" <td>1125</td>\n",
" <td>5.15</td>\n",
" <td>5.12</td>\n",
" <td>3.06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3308</th>\n",
" <td>0.73</td>\n",
" <td>Ideal</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>62.3</td>\n",
" <td>56.0</td>\n",
" <td>3370</td>\n",
" <td>5.75</td>\n",
" <td>5.80</td>\n",
" <td>3.60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7894</th>\n",
" <td>1.12</td>\n",
" <td>Very Good</td>\n",
" <td>I</td>\n",
" <td>SI1</td>\n",
" <td>60.6</td>\n",
" <td>60.0</td>\n",
" <td>4312</td>\n",
" <td>6.73</td>\n",
" <td>6.77</td>\n",
" <td>4.09</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21368</th>\n",
" <td>0.36</td>\n",
" <td>Ideal</td>\n",
" <td>D</td>\n",
" <td>SI1</td>\n",
" <td>62.2</td>\n",
" <td>53.0</td>\n",
" <td>626</td>\n",
" <td>4.57</td>\n",
" <td>4.59</td>\n",
" <td>2.85</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46144</th>\n",
" <td>0.50</td>\n",
" <td>Premium</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>61.3</td>\n",
" <td>59.0</td>\n",
" <td>1746</td>\n",
" <td>5.10</td>\n",
" <td>5.05</td>\n",
" <td>3.11</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>10789 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z\n",
"id \n",
"32452 0.39 Very Good E VS2 60.9 58.0 793 4.72 4.77 2.89\n",
"2432 0.72 Very Good E SI1 63.3 56.0 3183 5.67 5.71 3.60\n",
"16456 1.21 Ideal H SI1 62.1 59.0 6573 6.81 6.75 4.21\n",
"46045 0.56 Ideal D SI1 62.5 56.0 1729 5.28 5.24 3.29\n",
"11115 1.00 Good E SI1 62.4 59.0 4936 6.35 6.40 3.98\n",
"... ... ... ... ... ... ... ... ... ... ...\n",
"40250 0.50 Premium F SI1 59.6 61.0 1125 5.15 5.12 3.06\n",
"3308 0.73 Ideal E VS1 62.3 56.0 3370 5.75 5.80 3.60\n",
"7894 1.12 Very Good I SI1 60.6 60.0 4312 6.73 6.77 4.09\n",
"21368 0.36 Ideal D SI1 62.2 53.0 626 4.57 4.59 2.85\n",
"46144 0.50 Premium E VS2 61.3 59.0 1746 5.10 5.05 3.11\n",
"\n",
"[10789 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"id\n",
"32452 0\n",
"2432 0\n",
"16456 1\n",
"46045 0\n",
"11115 1\n",
" ..\n",
"40250 0\n",
"3308 0\n",
"7894 1\n",
"21368 0\n",
"46144 0\n",
"Name: above_average_carat, Length: 10789, dtype: int64"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \"\"\"\n",
" Splits a Pandas dataframe into three subsets (train, val, and test)\n",
" following fractional ratios provided by the user, where each subset is\n",
" stratified by the values in a specific column (that is, each subset has\n",
" the same relative frequency of the values in the column). It performs this\n",
" splitting by running train_test_split() twice.\n",
" Parameters\n",
" ----------\n",
" df_input : Pandas dataframe\n",
" Input dataframe to be split.\n",
" stratify_colname : str\n",
" The name of the column that will be used for stratification. Usually\n",
" this column would be for the label.\n",
" frac_train : float\n",
" frac_val : float\n",
" frac_test : float\n",
" The ratios with which the dataframe will be split into train, val, and\n",
" test data. The values should be expressed as float fractions and should\n",
" sum to 1.0.\n",
" random_state : int, None, or RandomStateInstance\n",
" Value to be passed to train_test_split().\n",
" Returns\n",
" -------\n",
" df_train, df_val, df_test :\n",
" Dataframes containing the three splits.\n",
" \"\"\"\n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input # содержит все столбцы.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # содержит столбец для стратификации.\n",
" X = df.drop(['above_average_carat'], axis=1)\n",
" y = df['above_average_carat']\n",
" # Первичное разбиение на обучающую и временную выборки.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" # Вторичное разбиение на валидационную и тестовую выборки.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"#разбиение на 80% обучающей выборки и 20% тестовой выборки\n",
"#Стратификация выполняется по столбцу \"above_average_carat\"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"above_average_carat\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование конвейера для классификации данных\n",
"\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"features_engineering -- трансформер для конструирования признаков\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"class DaimondFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
" def fit(self, X, y=None):\n",
" return self\n",
" #добавляем новый столбец \"Length_to_Width_Ratio\" - отношение длины к ширине (столбцы x и y)\n",
" def transform(self, X, y=None):\n",
" X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n",
" return X\n",
" #добавляем имя нового столбца к входным именам\n",
" def get_feature_names_out(self, features_in):\n",
" return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
" \n",
"#Список столбцов, которые будут удалены из данных\n",
"columns_to_drop = []\n",
"#Список числовых столбцов\n",
"num_columns = [\"carat\", \"depth\", \"table\", \"x\", \"y\", \"z\"]\n",
"#Список категориальных столбцов\n",
"cat_columns = [\"cut\", \"color\", \"clarity\"]\n",
"\n",
"#Предобработка числовых столбцов\n",
"#Используется для заполнения пропущенных значений в числовых столбцах медианой\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"#Стандартизирует числовые столбцы, приводя их к нулевому среднему и единичному стандартному отклонению\n",
"num_scaler = StandardScaler()\n",
"#Создается конвейер для последовательного применения SimpleImputer и StandardScaler\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"#Предобработка категориальных столбцов\n",
"#Заполняет пропущенные значения в категориальных столбцов значением \"unknown\"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"#Кодирует категориальные переменные с помощью one-hot кодирования\n",
"#handle_unknown=ignore позволяет игнорировать неизвестные категории\n",
"#drop=first исключает первую категорию для предотвращения мультиколлинеарности\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"#Объединение этапов предобработки\n",
"#ColumnTransformer Позволяет применять разные преобразования к различным колонкам\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" #числовые и категориальные признаки обрабатываются разными конвейерами\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" #все остальные столбцы, которые не указаны в трансформерах, будут переданы без изменений\n",
" remainder=\"passthrough\"\n",
")\n",
"#конструирование признаков\n",
"#создается еще один ColumnTransformer, который применяет класс DaimondFeatures к столбцам x и y \n",
"#Все остальные столбцы будут переданы без изменений\n",
"features_engineering = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"add_features\", DaimondFeatures(), [\"x\", \"y\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"#удаление столбцов, в данном случае ничего не будет удалено\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"#Постобработка столбцов\n",
"#Создается еще один ColumnTransformer который обрабатывает столбец Cabin_type с помощью prepocessing_cat\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"#Финальный конвейер\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"features_engineering\", features_engineering),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Демонстрация работы конвейера для предобработки данных при классификации"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>Length_to_Width_Ratio</th>\n",
" <th>carat</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>z</th>\n",
" <th>cut_Good</th>\n",
" <th>cut_Ideal</th>\n",
" <th>cut_Premium</th>\n",
" <th>...</th>\n",
" <th>color_I</th>\n",
" <th>color_J</th>\n",
" <th>clarity_IF</th>\n",
" <th>clarity_SI1</th>\n",
" <th>clarity_SI2</th>\n",
" <th>clarity_VS1</th>\n",
" <th>clarity_VS2</th>\n",
" <th>clarity_VVS1</th>\n",
" <th>clarity_VVS2</th>\n",
" <th>price</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>38836</th>\n",
" <td>-0.907744</td>\n",
" <td>-0.863476</td>\n",
" <td>1.051267</td>\n",
" <td>-0.837490</td>\n",
" <td>0.176170</td>\n",
" <td>-0.648004</td>\n",
" <td>-0.857040</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1049</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30260</th>\n",
" <td>-0.934483</td>\n",
" <td>-0.889579</td>\n",
" <td>1.050478</td>\n",
" <td>-0.837490</td>\n",
" <td>0.876071</td>\n",
" <td>-0.201125</td>\n",
" <td>-0.814688</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>725</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33169</th>\n",
" <td>-1.050350</td>\n",
" <td>-1.002691</td>\n",
" <td>1.047532</td>\n",
" <td>-0.921885</td>\n",
" <td>0.036190</td>\n",
" <td>-0.648004</td>\n",
" <td>-1.012333</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>817</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1029</th>\n",
" <td>0.090496</td>\n",
" <td>0.154530</td>\n",
" <td>0.585622</td>\n",
" <td>-0.204531</td>\n",
" <td>-2.343471</td>\n",
" <td>0.692631</td>\n",
" <td>-0.151165</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2904</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53809</th>\n",
" <td>0.295492</td>\n",
" <td>0.311147</td>\n",
" <td>0.949688</td>\n",
" <td>0.027554</td>\n",
" <td>-0.733700</td>\n",
" <td>-0.648004</td>\n",
" <td>0.215890</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2733</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2937</th>\n",
" <td>0.063758</td>\n",
" <td>0.093624</td>\n",
" <td>0.680999</td>\n",
" <td>-0.056841</td>\n",
" <td>1.156031</td>\n",
" <td>-0.201125</td>\n",
" <td>0.215890</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3291</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7514</th>\n",
" <td>0.429185</td>\n",
" <td>0.389455</td>\n",
" <td>1.102015</td>\n",
" <td>0.217442</td>\n",
" <td>0.036190</td>\n",
" <td>2.480145</td>\n",
" <td>0.413535</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>4241</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48344</th>\n",
" <td>-0.408624</td>\n",
" <td>-0.350123</td>\n",
" <td>1.167088</td>\n",
" <td>-0.499912</td>\n",
" <td>0.246160</td>\n",
" <td>-1.631136</td>\n",
" <td>-0.348810</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1961</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3212</th>\n",
" <td>-0.052109</td>\n",
" <td>-0.089095</td>\n",
" <td>0.584874</td>\n",
" <td>-0.204531</td>\n",
" <td>0.036190</td>\n",
" <td>1.139510</td>\n",
" <td>-0.066460</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>3348</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35654</th>\n",
" <td>-1.255346</td>\n",
" <td>-1.246316</td>\n",
" <td>1.007245</td>\n",
" <td>-1.027378</td>\n",
" <td>0.946061</td>\n",
" <td>-0.201125</td>\n",
" <td>-1.153508</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>907</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>43154 rows × 25 columns</p>\n",
"</div>"
],
"text/plain": [
" x y Length_to_Width_Ratio carat depth \\\n",
"id \n",
"38836 -0.907744 -0.863476 1.051267 -0.837490 0.176170 \n",
"30260 -0.934483 -0.889579 1.050478 -0.837490 0.876071 \n",
"33169 -1.050350 -1.002691 1.047532 -0.921885 0.036190 \n",
"1029 0.090496 0.154530 0.585622 -0.204531 -2.343471 \n",
"53809 0.295492 0.311147 0.949688 0.027554 -0.733700 \n",
"... ... ... ... ... ... \n",
"2937 0.063758 0.093624 0.680999 -0.056841 1.156031 \n",
"7514 0.429185 0.389455 1.102015 0.217442 0.036190 \n",
"48344 -0.408624 -0.350123 1.167088 -0.499912 0.246160 \n",
"3212 -0.052109 -0.089095 0.584874 -0.204531 0.036190 \n",
"35654 -1.255346 -1.246316 1.007245 -1.027378 0.946061 \n",
"\n",
" table z cut_Good cut_Ideal cut_Premium ... color_I \\\n",
"id ... \n",
"38836 -0.648004 -0.857040 0.0 0.0 0.0 ... 0.0 \n",
"30260 -0.201125 -0.814688 0.0 0.0 0.0 ... 0.0 \n",
"33169 -0.648004 -1.012333 0.0 1.0 0.0 ... 0.0 \n",
"1029 0.692631 -0.151165 0.0 0.0 0.0 ... 0.0 \n",
"53809 -0.648004 0.215890 0.0 0.0 0.0 ... 0.0 \n",
"... ... ... ... ... ... ... ... \n",
"2937 -0.201125 0.215890 1.0 0.0 0.0 ... 0.0 \n",
"7514 2.480145 0.413535 1.0 0.0 0.0 ... 0.0 \n",
"48344 -1.631136 -0.348810 0.0 1.0 0.0 ... 0.0 \n",
"3212 1.139510 -0.066460 0.0 0.0 1.0 ... 0.0 \n",
"35654 -0.201125 -1.153508 0.0 0.0 0.0 ... 0.0 \n",
"\n",
" color_J clarity_IF clarity_SI1 clarity_SI2 clarity_VS1 \\\n",
"id \n",
"38836 0.0 0.0 0.0 0.0 0.0 \n",
"30260 0.0 0.0 1.0 0.0 0.0 \n",
"33169 0.0 0.0 0.0 0.0 1.0 \n",
"1029 0.0 0.0 0.0 0.0 1.0 \n",
"53809 0.0 0.0 1.0 0.0 0.0 \n",
"... ... ... ... ... ... \n",
"2937 0.0 0.0 0.0 0.0 0.0 \n",
"7514 0.0 0.0 1.0 0.0 0.0 \n",
"48344 0.0 0.0 0.0 0.0 0.0 \n",
"3212 0.0 0.0 0.0 0.0 0.0 \n",
"35654 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" clarity_VS2 clarity_VVS1 clarity_VVS2 price \n",
"id \n",
"38836 0.0 0.0 1.0 1049 \n",
"30260 0.0 0.0 0.0 725 \n",
"33169 0.0 0.0 0.0 817 \n",
"1029 0.0 0.0 0.0 2904 \n",
"53809 0.0 0.0 0.0 2733 \n",
"... ... ... ... ... \n",
"2937 1.0 0.0 0.0 3291 \n",
"7514 0.0 0.0 0.0 4241 \n",
"48344 0.0 1.0 0.0 1961 \n",
"3212 0.0 1.0 0.0 3348 \n",
"35654 0.0 0.0 1.0 907 \n",
"\n",
"[43154 rows x 25 columns]"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#Обучает все трансформеры в конвейере на данных обучающей выборки\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" #возвращает имена всех столбцов, которые были созданы в результате всех этапов предобработки в конвейере\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для классификации"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"#классификационные модели\n",
"class_models = {\n",
" #Логистическая регрессия\n",
" # от 0 до 1, принадлежит ли объект к классу\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" #гребневая регрессия\n",
" #Логическая, но с регуляризацией L2 (модель не так точно запоминает данные) и сбалансированными весами классов\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" #дерево решений с максимальной глубиной 7\n",
" #Деления данных на условия с помощью построения дерева\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
" },\n",
" #K-ближайших соседей с количеством соседей, равным 7\n",
" #Определяет ближайших объектов и находит и класс\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" #Наивный байесовский классификатор - Вероятности для классификации\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" #Градиентный бустинг с 210 деревьями\n",
" #Постепенно улучшает предсказания с помощью слабых моделей\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" #Случайный лес с максимальной глубиной 11 и сбалансированными весами классов\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
" )\n",
" },\n",
" #Многослойный персептрон (нейронная сеть)с одним скрытым слоем\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" #содержажит 7 нейронов\n",
" hidden_layer_sizes=(7,),\n",
" #максимальное количеством итераций 500\n",
" max_iter=500,\n",
" #включение ранней остановки для предотвращения переобучения\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\3 kurs\\МИИ\\1 лаб\\mai-main\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: ridge\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\3 kurs\\МИИ\\1 лаб\\mai-main\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"#проходим по всем моделям\n",
"for model_name in class_models.keys():\n",
" #выводим названия\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" #Получение модели и создание конвейера\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" #Предсказания на обучающей и тестовой выборке\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" #Сохранение результатов в словаре моделей\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" #Оценка производительности модели\n",
" #точность\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" #полнота\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" #аккуратность (верность) - Доля правильных предсказаний среди всех\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" #Площадь под кривой ROC, которая показывает качество классификации\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" #Гармоническое среднее между точностью и полнотой\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" #Коэффициент корреляции Мэтьюса, который учитывает все возможные результаты классификации\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" #Мера согласия между двумя классификаторами\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" #Матрица ошибок, показывающая количество истинных и ложных положительных и отрицательных предсказаний\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Сводная таблица оценок качества для использованных моделей классификации"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Матрица неточностей"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0cAAAQ9CAYAAACSpDaqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxU5f4H8M/MwAwIDIuyKiKKC7hvJbmmBpqmpuXNNHFPww1zyXI3pSxzy6XcTc2srFvupKCp6FVTU1RyQXEBXBAQZJuZ8/uDH6MTMDBycGDO531f53XhPM+ceWbS8/F7nrPIBEEQQEREREREJHFycw+AiIiIiIioPGBxREREREREBBZHREREREREAFgcERERERERAWBxREREREREBIDFEREREREREQAWR0RERERERABYHBEREREREQFgcURERERERASAxRE9pw0bNkAmk+HGjRtlsv0bN25AJpNhw4YNomwvKioKMpkMUVFRomyPiIjIUsyaNQsymaxEfWUyGWbNmlW2AyIyIxZHZFFWrFghWkFFRERERNJiZe4BEBXGx8cHmZmZsLa2Nul1K1asQJUqVTBo0CCD9e3atUNmZiaUSqWIoyQiIqr4pk2bho8++sjcwyAqF1gcUbkkk8lgY2Mj2vbkcrmo2yMiIrIEGRkZsLOzg5UV/0lIBPC0OhLRihUrUL9+fahUKnh5eSE0NBQpKSkF+i1fvhw1a9aEra0tXnrpJfz555/o0KEDOnTooO9T2DVHiYmJGDx4MKpVqwaVSgVPT0/07NlTf91TjRo1EBMTg0OHDkEmk0Emk+m3WdQ1RydOnMDrr78OZ2dn2NnZoVGjRliyZIm4XwwREVE5kH9t0cWLF/Huu+/C2dkZbdq0KfSao+zsbISFhcHV1RUODg7o0aMHbt++Xeh2o6Ki0KJFC9jY2KBWrVr45ptviryOafPmzWjevDlsbW3h4uKCd955B7du3SqTz0v0PHiYgEQxa9YszJ49G507d8aoUaMQGxuLlStX4uTJkzh69Kj+9LiVK1di9OjRaNu2LcLCwnDjxg306tULzs7OqFatmtH36NOnD2JiYjBmzBjUqFED9+7dQ0REBOLj41GjRg0sXrwYY8aMgb29PT755BMAgLu7e5Hbi4iIQPfu3eHp6Ylx48bBw8MDly5dws6dOzFu3DjxvhwiIqJy5O2330bt2rUxf/58CIKAe/fuFegzbNgwbN68Ge+++y5eeeUVHDx4EN26dSvQ78yZM+jSpQs8PT0xe/ZsaLVazJkzB66urgX6zps3D9OnT0ffvn0xbNgw3L9/H8uWLUO7du1w5swZODk5lcXHJTKNQPQc1q9fLwAQ4uLihHv37glKpVIICgoStFqtvs/XX38tABDWrVsnCIIgZGdnC5UrVxZatmwp5Obm6vtt2LBBACC0b99evy4uLk4AIKxfv14QBEF49OiRAED44osvjI6rfv36BtvJFxkZKQAQIiMjBUEQBI1GI/j6+go+Pj7Co0ePDPrqdLqSfxFEREQVxMyZMwUAQr9+/Qpdn+/s2bMCAOGDDz4w6Pfuu+8KAISZM2fq173xxhtCpUqVhDt37ujXXblyRbCysjLY5o0bNwSFQiHMmzfPYJvnz58XrKysCqwnMheeVkel9scffyAnJwfjx4+HXP70j9Tw4cOhVquxa9cuAMCpU6fw8OFDDB8+3ODc5v79+8PZ2dnoe9ja2kKpVCIqKgqPHj0q9ZjPnDmDuLg4jB8/vsCRqpLezpSIiKgiGjlypNH23bt3AwDGjh1rsH78+PEGv2u1Wvzxxx/o1asXvLy89Ov9/PzQtWtXg747duyATqdD37598eDBA/3i4eGB2rVrIzIyshSfiEg8PK2OSu3mzZsAgLp16xqsVyqVqFmzpr49///9/PwM+llZWaFGjRpG30OlUuHzzz/Hhx9+CHd3d7Rq1Qrdu3fHwIED4eHhYfKYr127BgBo0KCBya8lIiKqyHx9fY2237x5E3K5HLVq1TJY/++cv3fvHjIzMwvkOlAw669cuQJBEFC7du1C39PUu9MSlRUWR1RhjB8/Hm+88QZ+/fVX7Nu3D9OnT0d4eDgOHjyIpk2bmnt4REREFYKtre0Lf0+dTgeZTIY9e/ZAoVAUaLe3t3/hYyIqDE+ro1Lz8fEBAMTGxhqsz8nJQVxcnL49//+vXr1q0E+j0ejvOFecWrVq4cMPP8T+/ftx4cIF5OTkYOHChfr2kp4Sl3807MKFCyXqT0REJBU+Pj7Q6XT6syzy/Tvn3dzcYGNjUyDXgYJZX6tWLQiCAF9fX3Tu3LnA0qpVK/E/CNFzYHFEpda5c2colUosXboUgiDo169duxapqan6u9u0aNEClStXxurVq6HRaPT9tmzZUux1RE+ePEFWVpbBulq1asHBwQHZ2dn6dXZ2doXePvzfmjVrBl9fXyxevLhA/2c/AxERkdTkXy+0dOlSg/WLFy82+F2hUKBz58749ddfcffuXf36q1evYs+ePQZ9e/fuDYVCgdmzZxfIWUEQ8PDhQxE/AdHz42l1VGqurq6YOnUqZs+ejS5duqBHjx6IjY3FihUr0LJlSwwYMABA3jVIs2bNwpgxY9CxY0f07dsXN27cwIYNG1CrVi2jsz7//PMPOnXqhL59+yIgIABWVlb45ZdfkJSUhHfeeUffr3nz5li5ciU+/fRT+Pn5wc3NDR07diywPblcjpUrV+KNN95AkyZNMHjwYHh6euLy5cuIiYnBvn37xP+iiIiIKoAmTZqgX79+WLFiBVJTU/HKK6/gwIEDhc4QzZo1C/v370fr1q0xatQoaLVafP3112jQoAHOnj2r71erVi18+umnmDp1qv4xHg4ODoiLi8Mvv/yCESNGYOLEiS/wUxIVjsURiWLWrFlwdXXF119/jbCwMLi4uGDEiBGYP3++wUWWo0ePhiAIWLhwISZOnIjGjRvjt99+w9ixY2FjY1Pk9r29vdGvXz8cOHAA3333HaysrFCvXj1s374dffr00febMWMGbt68iQULFuDx48do3759ocURAAQHByMyMhKzZ8/GwoULodPpUKtWLQwfPly8L4aIiKgCWrduHVxdXbFlyxb8+uuv6NixI3bt2gVvb2+Dfs2bN8eePXswceJETJ8+Hd7e3pgzZw4uXbqEy5cvG/T96KOPUKdOHSxatAizZ88GkJfvQUFB6NGjxwv7bETGyASeQ0RmptPp4Orqit69e2P16tXmHg4RERGVUq9evRATE4MrV66YeyhEJuE1R/RCZWVlFTjXeNOmTUhOTkaHDh3MMygiIiJ6bpmZmQa/X7lyBbt372auU4XEmSN6oaKiohAWFoa3334blStXxl9//YW1a9fC398fp0+fhlKpNPcQiYiIyASenp4YNGiQ/tmGK1euRHZ2Ns6cOVPkc42Iyitec0QvVI0aNeDt7Y2lS5ciOTkZLi4uGDhwID777DMWRkRERBVQly5d8P333yMxMREqlQqBgYGYP38+CyOqkDhzREREREREBF5zREREREREBIDFEREREREREQBec1QiOp0Od+/ehYODg9EHlRJZIkEQ8PjxY3h5eUEuF/d4SlZWFnJycortp1QqjT4Hi4ikh9lMUsZsLjssjkrg7t27BR56RiQ1t27dQrVq1UTbXlZWFnx97JF4T1tsXw8PD8TFxVnkTpiIng+zmYjZXBZYHJWAg4MDAODmXzWgtueZiObwZp2G5h6CZGmQiyPYrf97IJacnBwk3tPi6ilvqB2K/nuV9lgHvxa3kJOTY3E7YCJ6fsxm82M2mw+zueywOCqB/Ol6tb3c6B8UKjtWMmtzD0G6/v9+lmV12oq9gwz2DkVvWwfT3/fOnTuYMmUK9uzZgydPnsDPzw/r169HixYtAOSdjjBz5kysXr0aKSkpaN26NVauXGlw29nk5GSMGTMGv//+O+RyOfr06YMlS5bA3t5e3+fvv/9GaGgoTp48CVdXV4wZMwaTJ082ebxEZDpms/kxm82I2Vxm2cy9CRGZVa6gLXYxxaNHj9C6dWtYW1tjz549uHjxIhYuXAhnZ2d9nwULFmDp0qV
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"#Функция, создающая фигуру и набор подграфиков (axes)\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_b8c4d_row0_col0, #T_b8c4d_row0_col1, #T_b8c4d_row0_col2, #T_b8c4d_row0_col3, #T_b8c4d_row1_col0, #T_b8c4d_row1_col1, #T_b8c4d_row1_col2, #T_b8c4d_row1_col3, #T_b8c4d_row2_col0, #T_b8c4d_row2_col1, #T_b8c4d_row2_col2, #T_b8c4d_row2_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row0_col4, #T_b8c4d_row0_col5, #T_b8c4d_row0_col6, #T_b8c4d_row0_col7, #T_b8c4d_row1_col4, #T_b8c4d_row1_col5, #T_b8c4d_row1_col6, #T_b8c4d_row1_col7, #T_b8c4d_row2_col4, #T_b8c4d_row2_col5, #T_b8c4d_row2_col6, #T_b8c4d_row2_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row3_col0 {\n",
" background-color: #7fd34e;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row3_col1, #T_b8c4d_row5_col0 {\n",
" background-color: #8ed645;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row3_col2 {\n",
" background-color: #90d743;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row3_col3, #T_b8c4d_row4_col1 {\n",
" background-color: #93d741;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row3_col4, #T_b8c4d_row3_col6 {\n",
" background-color: #ca457a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row3_col5, #T_b8c4d_row3_col7 {\n",
" background-color: #cf4c74;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row4_col0, #T_b8c4d_row4_col3 {\n",
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row4_col2 {\n",
" background-color: #6ece58;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row4_col4 {\n",
" background-color: #c43e7f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row4_col5, #T_b8c4d_row4_col7 {\n",
" background-color: #ce4b75;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row4_col6 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row5_col1, #T_b8c4d_row6_col3 {\n",
" background-color: #9bd93c;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row5_col2 {\n",
" background-color: #5ac864;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row5_col3 {\n",
" background-color: #73d056;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row5_col4, #T_b8c4d_row5_col6 {\n",
" background-color: #c13b82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row5_col5, #T_b8c4d_row5_col7 {\n",
" background-color: #cc4778;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row6_col0, #T_b8c4d_row6_col1, #T_b8c4d_row7_col2, #T_b8c4d_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row6_col2 {\n",
" background-color: #9dd93b;\n",
" color: #000000;\n",
"}\n",
"#T_b8c4d_row6_col4 {\n",
" background-color: #6300a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row6_col5 {\n",
" background-color: #7100a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row6_col6 {\n",
" background-color: #6700a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row6_col7 {\n",
" background-color: #7501a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row7_col0 {\n",
" background-color: #2eb37c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row7_col1 {\n",
" background-color: #25ac82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b8c4d_row7_col4, #T_b8c4d_row7_col5, #T_b8c4d_row7_col6, #T_b8c4d_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_b8c4d\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_b8c4d_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_b8c4d_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_b8c4d_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_b8c4d_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_b8c4d_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_b8c4d_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_b8c4d_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_b8c4d_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_b8c4d_level0_row0\" class=\"row_heading level0 row0\" >decision_tree</th>\n",
" <td id=\"T_b8c4d_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b8c4d_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
" <td id=\"T_b8c4d_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b8c4d_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_b8c4d_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_b8c4d_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b8c4d_level0_row3\" class=\"row_heading level0 row3\" >mlp</th>\n",
" <td id=\"T_b8c4d_row3_col0\" class=\"data row3 col0\" >0.992471</td>\n",
" <td id=\"T_b8c4d_row3_col1\" class=\"data row3 col1\" >0.994967</td>\n",
" <td id=\"T_b8c4d_row3_col2\" class=\"data row3 col2\" >0.996822</td>\n",
" <td id=\"T_b8c4d_row3_col3\" class=\"data row3 col3\" >0.996494</td>\n",
" <td id=\"T_b8c4d_row3_col4\" class=\"data row3 col4\" >0.995458</td>\n",
" <td id=\"T_b8c4d_row3_col5\" class=\"data row3 col5\" >0.996385</td>\n",
" <td id=\"T_b8c4d_row3_col6\" class=\"data row3 col6\" >0.994642</td>\n",
" <td id=\"T_b8c4d_row3_col7\" class=\"data row3 col7\" >0.995730</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b8c4d_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
" <td id=\"T_b8c4d_row4_col0\" class=\"data row4 col0\" >0.993907</td>\n",
" <td id=\"T_b8c4d_row4_col1\" class=\"data row4 col1\" >0.996049</td>\n",
" <td id=\"T_b8c4d_row4_col2\" class=\"data row4 col2\" >0.992164</td>\n",
" <td id=\"T_b8c4d_row4_col3\" class=\"data row4 col3\" >0.994521</td>\n",
" <td id=\"T_b8c4d_row4_col4\" class=\"data row4 col4\" >0.994114</td>\n",
" <td id=\"T_b8c4d_row4_col5\" class=\"data row4 col5\" >0.996014</td>\n",
" <td id=\"T_b8c4d_row4_col6\" class=\"data row4 col6\" >0.993035</td>\n",
" <td id=\"T_b8c4d_row4_col7\" class=\"data row4 col7\" >0.995285</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b8c4d_level0_row5\" class=\"row_heading level0 row5\" >logistic</th>\n",
" <td id=\"T_b8c4d_row5_col0\" class=\"data row5 col0\" >0.995093</td>\n",
" <td id=\"T_b8c4d_row5_col1\" class=\"data row5 col1\" >0.997574</td>\n",
" <td id=\"T_b8c4d_row5_col2\" class=\"data row5 col2\" >0.989041</td>\n",
" <td id=\"T_b8c4d_row5_col3\" class=\"data row5 col3\" >0.991234</td>\n",
" <td id=\"T_b8c4d_row5_col4\" class=\"data row5 col4\" >0.993303</td>\n",
" <td id=\"T_b8c4d_row5_col5\" class=\"data row5 col5\" >0.995273</td>\n",
" <td id=\"T_b8c4d_row5_col6\" class=\"data row5 col6\" >0.992058</td>\n",
" <td id=\"T_b8c4d_row5_col7\" class=\"data row5 col7\" >0.994394</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b8c4d_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_b8c4d_row6_col0\" class=\"data row6 col0\" >0.946797</td>\n",
" <td id=\"T_b8c4d_row6_col1\" class=\"data row6 col1\" >0.942845</td>\n",
" <td id=\"T_b8c4d_row6_col2\" class=\"data row6 col2\" >0.998521</td>\n",
" <td id=\"T_b8c4d_row6_col3\" class=\"data row6 col3\" >0.997808</td>\n",
" <td id=\"T_b8c4d_row6_col4\" class=\"data row6 col4\" >0.975645</td>\n",
" <td id=\"T_b8c4d_row6_col5\" class=\"data row6 col5\" >0.973492</td>\n",
" <td id=\"T_b8c4d_row6_col6\" class=\"data row6 col6\" >0.971971</td>\n",
" <td id=\"T_b8c4d_row6_col7\" class=\"data row6 col7\" >0.969549</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b8c4d_level0_row7\" class=\"row_heading level0 row7\" >knn</th>\n",
" <td id=\"T_b8c4d_row7_col0\" class=\"data row7 col0\" >0.972527</td>\n",
" <td id=\"T_b8c4d_row7_col1\" class=\"data row7 col1\" >0.966497</td>\n",
" <td id=\"T_b8c4d_row7_col2\" class=\"data row7 col2\" >0.962082</td>\n",
" <td id=\"T_b8c4d_row7_col3\" class=\"data row7 col3\" >0.954635</td>\n",
" <td id=\"T_b8c4d_row7_col4\" class=\"data row7 col4\" >0.972471</td>\n",
" <td id=\"T_b8c4d_row7_col5\" class=\"data row7 col5\" >0.966818</td>\n",
" <td id=\"T_b8c4d_row7_col6\" class=\"data row7 col6\" >0.967276</td>\n",
" <td id=\"T_b8c4d_row7_col7\" class=\"data row7 col7\" >0.960529</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1659b03a1e0>"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"#сортировка по столбцу Accuracy_test в порядке убывания\n",
"#чтобы увидеть модели с наивысшей точностью на тестовой выборке в начале таблицы\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient( #визуализация\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" #Указывает, какие столбцы будут окрашены\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_76f05_row0_col0, #T_76f05_row0_col1, #T_76f05_row1_col0, #T_76f05_row1_col1, #T_76f05_row2_col0, #T_76f05_row2_col1, #T_76f05_row3_col0, #T_76f05_row3_col1, #T_76f05_row4_col0, #T_76f05_row4_col1, #T_76f05_row5_col0, #T_76f05_row5_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_76f05_row0_col2, #T_76f05_row0_col3, #T_76f05_row0_col4, #T_76f05_row1_col2, #T_76f05_row1_col3, #T_76f05_row1_col4, #T_76f05_row2_col2, #T_76f05_row2_col3, #T_76f05_row2_col4, #T_76f05_row3_col2, #T_76f05_row3_col3, #T_76f05_row3_col4, #T_76f05_row4_col2, #T_76f05_row4_col3, #T_76f05_row4_col4, #T_76f05_row5_col2, #T_76f05_row5_col3, #T_76f05_row5_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_76f05_row6_col0, #T_76f05_row6_col1 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_76f05_row6_col2 {\n",
" background-color: #d45270;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_76f05_row6_col3, #T_76f05_row6_col4 {\n",
" background-color: #d8576b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_76f05_row7_col0, #T_76f05_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_76f05_row7_col2, #T_76f05_row7_col3, #T_76f05_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_76f05\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_76f05_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_76f05_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_76f05_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_76f05_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_76f05_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_76f05_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_76f05_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_76f05_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_76f05_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_76f05_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_76f05_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
" <td id=\"T_76f05_row6_col0\" class=\"data row6 col0\" >0.999629</td>\n",
" <td id=\"T_76f05_row6_col1\" class=\"data row6 col1\" >0.999562</td>\n",
" <td id=\"T_76f05_row6_col2\" class=\"data row6 col2\" >0.999754</td>\n",
" <td id=\"T_76f05_row6_col3\" class=\"data row6 col3\" >0.999240</td>\n",
" <td id=\"T_76f05_row6_col4\" class=\"data row6 col4\" >0.999240</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row7\" class=\"row_heading level0 row7\" >knn</th>\n",
" <td id=\"T_76f05_row7_col0\" class=\"data row7 col0\" >0.980536</td>\n",
" <td id=\"T_76f05_row7_col1\" class=\"data row7 col1\" >0.976933</td>\n",
" <td id=\"T_76f05_row7_col2\" class=\"data row7 col2\" >0.995960</td>\n",
" <td id=\"T_76f05_row7_col3\" class=\"data row7 col3\" >0.960098</td>\n",
" <td id=\"T_76f05_row7_col4\" class=\"data row7 col4\" >0.960107</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1659f0e94f0>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#выводим лучшую модель\n",
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>Predicted</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [carat, Predicted, cut, color, clarity, depth, table, price, x, y, z, above_average_carat]\n",
"Index: []"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#предобработка тестовой выборки\n",
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"#Получение предсказаний для лучшей модели\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"#Определение индексов ошибок\n",
"#Сравнивает истинные значения из y_test (столбец above_average_carat) с предсказанными значениями y_pred\n",
"#Это дает булев массив, где True означает ошибку\n",
"error_index = y_test[y_test[\"above_average_carat\"] != y_pred].index.tolist()\n",
"#выводим кол-во ошибок\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>4500</th>\n",
" <td>0.9</td>\n",
" <td>Premium</td>\n",
" <td>H</td>\n",
" <td>SI1</td>\n",
" <td>61.9</td>\n",
" <td>58.0</td>\n",
" <td>3629</td>\n",
" <td>6.2</td>\n",
" <td>6.15</td>\n",
" <td>3.82</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z \\\n",
"4500 0.9 Premium H SI1 61.9 58.0 3629 6.2 6.15 3.82 \n",
"\n",
" above_average_carat \n",
"4500 1 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>Length_to_Width_Ratio</th>\n",
" <th>carat</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" <th>cut_Good</th>\n",
" <th>cut_Ideal</th>\n",
" <th>...</th>\n",
" <th>color_I</th>\n",
" <th>color_J</th>\n",
" <th>clarity_IF</th>\n",
" <th>clarity_SI1</th>\n",
" <th>clarity_SI2</th>\n",
" <th>clarity_VS1</th>\n",
" <th>clarity_VS2</th>\n",
" <th>clarity_VVS1</th>\n",
" <th>clarity_VVS2</th>\n",
" <th>price</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>4500</th>\n",
" <td>0.420272</td>\n",
" <td>0.363352</td>\n",
" <td>1.156653</td>\n",
" <td>0.217442</td>\n",
" <td>0.10618</td>\n",
" <td>0.245753</td>\n",
" <td>0.399417</td>\n",
" <td>1.168162</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3629.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1 rows × 26 columns</p>\n",
"</div>"
],
"text/plain": [
" x y Length_to_Width_Ratio carat depth table \\\n",
"4500 0.420272 0.363352 1.156653 0.217442 0.10618 0.245753 \n",
"\n",
" z above_average_carat cut_Good cut_Ideal ... color_I \\\n",
"4500 0.399417 1.168162 0.0 0.0 ... 0.0 \n",
"\n",
" color_J clarity_IF clarity_SI1 clarity_SI2 clarity_VS1 clarity_VS2 \\\n",
"4500 0.0 0.0 1.0 0.0 0.0 0.0 \n",
"\n",
" clarity_VVS1 clarity_VVS2 price \n",
"4500 0.0 0.0 3629.0 \n",
"\n",
"[1 rows x 26 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 1 (proba: [4.39873930e-04 9.99560126e-01])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 1'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 4500\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"#Получение вероятностей предсказания\n",
"result_proba = model.predict_proba(test)[0]\n",
"#Получение предсказания\n",
"result = model.predict(test)[0]\n",
"#Получение реального значения\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 209,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 2,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 20}"
]
},
"execution_count": 209,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"#используем модель случайного леса\n",
"optimized_model_type = \"random_forest\"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"#Определение сетки гиперпараметров\n",
"param_grid = {\n",
" #Количество деревьев в лесу\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" #Количество столбцов\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" #Максимальная глубина дерева\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10],\n",
" #Критерий оценки качества разбиения\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"}\n",
"#Указываем модель, которую мы хотим оптимизировать\n",
"#Передаем сетку гиперпараметров для поиска\n",
"#Указываем, что хотим использовать все доступные процессоры для параллельной обработки\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"#Метод fit запускает процесс поиска по сетке и обучает модель с различными комбинациями гиперпараметров\n",
"#метод ravel() для преобразования меток в одномерный массив\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"#выводим наилучшие значения гиперпараметров\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 210,
"metadata": {},
"outputs": [],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=7,\n",
" max_features=\"sqrt\",\n",
" n_estimators=30,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_703b0_row0_col0, #T_703b0_row0_col1, #T_703b0_row0_col2, #T_703b0_row0_col3, #T_703b0_row1_col0, #T_703b0_row1_col1, #T_703b0_row1_col2, #T_703b0_row1_col3 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_703b0_row0_col4, #T_703b0_row0_col5, #T_703b0_row0_col6, #T_703b0_row0_col7, #T_703b0_row1_col4, #T_703b0_row1_col5, #T_703b0_row1_col6, #T_703b0_row1_col7 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_703b0\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_703b0_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_703b0_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_703b0_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_703b0_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_703b0_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_703b0_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_703b0_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_703b0_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_703b0_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_703b0_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_703b0_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_703b0_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x21f4cb4fe90>"
]
},
"execution_count": 212,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_0ba78_row0_col0, #T_0ba78_row0_col1, #T_0ba78_row1_col0, #T_0ba78_row1_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0ba78_row0_col2, #T_0ba78_row0_col3, #T_0ba78_row0_col4, #T_0ba78_row1_col2, #T_0ba78_row1_col3, #T_0ba78_row1_col4 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_0ba78\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_0ba78_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_0ba78_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_0ba78_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_0ba78_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_0ba78_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_0ba78_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_0ba78_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0ba78_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_0ba78_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x21f4cb4ef30>"
]
},
"execution_count": 213,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 215,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2kAAAGsCAYAAABHMu+IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABRUElEQVR4nO3dfVwVZf7/8fcBPIDAAVEBSSSMvGHz3k1ZLbMQNLcybdtaNU3T1cVKzZvcLe8qbW3NshutTNHd/Jnb3Te1NLNVU8nUsszUvCssBSsDROX2nN8fxqnjXYMMnOH4ej4e80hmLq5zDfng7Weua2ZsLpfLJQAAAACAJfh5ewAAAAAAgF9QpAEAAACAhVCkAQAAAICFUKQBAAAAgIVQpAEAAACAhVCkAQAAAICFUKQBAAAAgIVQpAEAAACAhQR4ewAAgHMrLCxUcXGxaf3Z7XYFBQWZ1h8AABVBrhlHkQYAFlRYWKiE+FBlHy0zrc+YmBgdPHjQZwMNAGBd5FrFUKQBgAUVFxcr+2iZDm6LlyOs8ivT8487ldDuGxUXF/tkmAEArI1cqxiKNACwMEeYnylhBgCAFZBrxlCkAYCFlbmcKnOZ0w8AAN5GrhlDkQYAFuaUS05VPs3M6AMAgMoi14xhrhEAAAAALISZNACwMKecMmNBhzm9AABQOeSaMRRpAGBhZS6XylyVX9JhRh8AAFQWuWYMyx0BAAAAwEKYSQMAC+MGawCALyHXjGEmDQAszCmXykzYLibMvvvuO/Xr109169ZVcHCwWrRooa1bt7qPu1wuTZw4UQ0aNFBwcLBSUlK0d+9ejz6OHTumvn37yuFwKCIiQoMHD1ZBQYFHm88//1zXXHONgoKCFBcXpxkzZlzcDwsAYHnkmjEUaQCAs/z000/q1KmTatWqpXfffVdffvmlZs6cqTp16rjbzJgxQ7Nnz9bcuXO1efNmhYSEKC0tTYWFhe42ffv21c6dO7V69WotX75c69ev19ChQ93H8/PzlZqaqvj4eG3btk1PPPGEJk+erBdffLFazxcA4NtqWq7ZXC4fv+sOAGqg/Px8hYeHa//uGIWFVf562vHjTl3RLFt5eXlyOBy/2f7BBx/Uxo0b9eGHH57zuMvlUmxsrB544AGNGTNGkpSXl6fo6GhlZGTojjvu0K5du5SUlKQtW7aoffv2kqSVK1fqxhtv1LfffqvY2FjNmTNH//jHP5SdnS273e7+7Lfeeku7d++u9HkDAKyBXKtYrjGTBgCXkPz8fI+tqKjonO3efvtttW/fXn/6058UFRWlNm3a6KWXXnIfP3jwoLKzs5WSkuLeFx4erg4dOigzM1OSlJmZqYiICHeQSVJKSor8/Py0efNmd5trr73WHWSSlJaWpj179uinn34y9dwBAL7HV3ONIg0ALKz8UcVmbJIUFxen8PBw9zZ9+vRzfu6BAwc0Z84cXXnllVq1apWGDx+u++67TwsXLpQkZWdnS5Kio6M9vi86Otp9LDs7W1FRUR7HAwICFBkZ6dHmXH38+jMAAL6DXDOGpzsCgIU5f97M6EeSDh065LEsJDAw8NztnU61b99e06ZNkyS1adNGX3zxhebOnasBAwaYMCIAwKWIXDOGmTQAuIQ4HA6P7Xxh1qBBAyUlJXnsa968ubKysiRJMTExkqScnByPNjk5Oe5jMTExOnr0qMfx0tJSHTt2zKPNufr49WcAAHA+vpprFGkAYGFmPKa4fKuITp06ac+ePR77vvrqK8XHx0uSEhISFBMTozVr1riP5+fna/PmzUpOTpYkJScnKzc3V9u2bXO3+eCDD+R0OtWhQwd3m/Xr16ukpMTdZvXq1WratKnHE7cAAL6BXDOGIg0ALKzMZd5WEaNGjdJHH32kadOmad++fVq8eLFefPFFpaenS5JsNptGjhypRx99VG+//bZ27Nihu+66S7GxserVq5ek01cou3fvriFDhujjjz/Wxo0bNWLECN1xxx2KjY2VJP3lL3+R3W7X4MGDtXPnTr366qt6+umnNXr0aDN/jAAAiyDXjOER/ABgQeWPKv78yyjTHlXcMumo4UcVS9Ly5cs1YcIE7d27VwkJCRo9erSGDBniPu5yuTRp0iS9+OKLys3NVefOnfX888+rSZMm7jbHjh3TiBEjtGzZMvn5+alPnz6aPXu2QkND3W0+//xzpaena8uWLapXr57uvfdejR8/vtLnDACwDnKtYrlGkQYAFlQeZttNDLPWFQwzAADMQq5VDE93BAALc8qmMtlM6QcAAG8j14zhnjQAAAAAsBBm0gDAwpyu05sZ/QAA4G3kmjEUaQBgYWUmLQsxow8AACqLXDOG5Y4AAAAAYCHMpAGAhXHFEQDgS8g1Y5hJAwAAAAALYSYNACzM6bLJ6TLhUcUm9AEAQGWRa8ZQpAGAhbEsBADgS8g1Y1juCAAAAAAWwkwaAFhYmfxUZsL1tDITxgIAQGWRa8ZQpAGAhblMWrvv8vG1+wCAmoFcM4bljgAAAABgIcykAYCFcYM1AMCXkGvGUKQBgIWVufxU5jJh7b7LhMEAAFBJ5JoxLHcEAAAAAAthJg0ALMwpm5wmXE9zyscvOQIAagRyzRhm0gAAAADAQphJAwAL4wZrAIAvIdeMoUgDAAsz7wZr314WAgCoGcg1Y1juCAAAAAAWwkwaAFjY6RusK7+kw4w+AACoLHLNGIo0ALAwp/xUxlOwAAA+glwzhuWOAAAAAGAhzKQBgIVxgzUAwJeQa8ZQpAGAhTnlx0s/AQA+g1wzhuWOAAAAAGAhzKQBgIWVuWwqc5nw0k8T+gAAoLLINWMo0gDAwspMegpWmY8vCwEA1AzkmjEsdwQAAAAAC2EmDQAszOnyk9OEp2A5ffwpWACAmoFcM4aZNAAAAACwEGbSAMDCWLsPAPAl5JoxFGkAYGFOmfMEK2flhwIAQKWRa8aw3BEAAAAALISZNACwMKf85DThepoZfQAAUFnkmjEUaQBgYWUuP5WZ8BQsM/oAAKCyyDVjfPvsAAAAAKCGYSYNACzMKZucMuMG68r3AQBAZZFrxlCkAYCFsSwEAOBLyDVjfPvsAAAAAKCGYSYNACzMvJd+ck0OAOB95Joxvn12AAAAAFDDMJNmgNPp1OHDhxUWFiabzbdvUgRQeS6XS8ePH1dsbKz8/Cp3LczpssnpMuEGaxP6gO8g1wBUBLlW/SjSDDh8+LDi4uK8PQwANcyhQ4fUsGHDSvXhNGlZiK+/9BMVQ64BuBjkWvWhSDMgLCxMkvTNJ5fLEerbfyFQcbc2aeHtIcBiSlWiDXrH/bsDsBpyDRdCruFM5Fr1o0gzoHwpiCPUT44wwgyeAmy1vD0EWI3r9H/MWEbmdPnJacJjhs3oA76DXMOFkGs4C7lW7SjSAMDCymRTmQkv7DSjDwAAKotcM8a3S1AAAAAAqGGYSQMAC2NZCADAl5BrxlCkAYCFlcmcJR1llR8KAACVRq4Z49slKAAAAADUMMykAYCFsSwEAOBLyDVjfPvsAKCGK3P5mbZVxOTJk2Wz2Ty2Zs2auY8XFhYqPT1ddevWVWhoqPr06aOcnByPPrKystSzZ0/Vrl1bUVFRGjt2rEpLSz3arF27Vm3btlVgYKASExOVkZFx0T8rAID1kWvGUKQBAM7pd7/7nY4cOeLeNmzY4D42atQoLVu2TP/973+1bt06HT58WL1793YfLysrU8+ePVVcXKxNmzZp4cKFysjI0MSJE91tDh48qJ49e6pr167avn27Ro4cqXvuuUerVq2q1vMEAFwaalKusdwRACzMJZucJtxg7bqIPgICAhQTE3PW/ry8PL388stavHixrr/+eknSggUL1Lx5c3300Ufq2LGj3nvvPX355Zd6//33FR0drdatW+uRRx7R+PHjNXnyZNntds2dO1cJCQmaOXOmJKl58+basGGDZs2apbS0tMqdMADAksg1Y5hJA4BLSH5+vsdWVFR03rZ79+5VbGysGjdurL59+yorK0uStG3bNpWUlCglJcXdtlmzZmrUqJEyMzMlSZmZmWrRooWio6PdbdLS0pSfn6+dO3e62/y6j/I25X0
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}