MII/mai/lab4.ipynb

2024-10-25 22:20:23 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lab 4\n",
"\n",
"Dataset: **Diamond Prices** (https://www.kaggle.com/datasets/nancyalaswad90/diamonds-prices)\n",
"\n",
"1. **carat**: weight of the diamond in carats\n",
"2. **cut**: quality of the cut (from worst to best: Fair, Good, Very Good, Premium, Ideal)\n",
"3. **color**: colour grade of the diamond, from D (best) to J (worst)\n",
"4. **clarity**: clarity grade, from I1 (worst) through SI2, SI1, VS2, VS1, VVS2, VVS1 to IF (best)\n",
"5. **depth**: total depth percentage, i.e. the height `z` relative to the average diameter `(x + y) / 2`\n",
"6. **table**: width of the flat top facet relative to the widest point, in percent\n",
"7. **price**: price of the diamond in US dollars\n",
"8. **x**: length of the diamond in millimetres\n",
"9. **y**: width of the diamond in millimetres\n",
"10. **z**: depth (height) of the diamond in millimetres"
]
},
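{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the `depth` column, assuming the usual definition of total depth percentage, `depth = z / mean(x, y) * 100 = 2 * z / (x + y) * 100`. The sketch uses three rows (ids 1-3) copied by hand from the dataset loaded further down; differences of a few tenths of a percent are to be expected, presumably because `x`, `y` and `z` are stored rounded to 0.01 mm."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Rows with ids 1-3 copied from the loaded dataset: x, y, z in mm, depth in %\n",
"sample = pd.DataFrame(\n",
"    {\n",
"        'x': [3.95, 3.89, 4.05],\n",
"        'y': [3.98, 3.84, 4.07],\n",
"        'z': [2.43, 2.31, 2.31],\n",
"        'depth': [61.5, 59.8, 56.9],\n",
"    },\n",
"    index=pd.Index([1, 2, 3], name='id'),\n",
")\n",
"\n",
"# Assumed definition: total depth % = z / mean(x, y) * 100\n",
"sample['depth_from_xyz'] = (2 * sample['z'] / (sample['x'] + sample['y']) * 100).round(1)\n",
"print(sample[['depth', 'depth_from_xyz']])"
]
},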
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Business goals**: \n",
"1. Predict diamond prices from their characteristics.\n",
"2. Analyse the frequency and combinations of the characteristics of the diamonds in highest demand, to plan inventory better."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Loading the dataset"
]
},
{
"cell_type": "code",
"execution_count": 190,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Среднее значение поля 'карат': 0.7979346717831785\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.23</td>\n",
" <td>Ideal</td>\n",
" <td>E</td>\n",
" <td>SI2</td>\n",
" <td>61.5</td>\n",
" <td>55.0</td>\n",
" <td>326</td>\n",
" <td>3.95</td>\n",
" <td>3.98</td>\n",
" <td>2.43</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.21</td>\n",
" <td>Premium</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>59.8</td>\n",
" <td>61.0</td>\n",
" <td>326</td>\n",
" <td>3.89</td>\n",
" <td>3.84</td>\n",
" <td>2.31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.23</td>\n",
" <td>Good</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>56.9</td>\n",
" <td>65.0</td>\n",
" <td>327</td>\n",
" <td>4.05</td>\n",
" <td>4.07</td>\n",
" <td>2.31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.29</td>\n",
" <td>Premium</td>\n",
" <td>I</td>\n",
" <td>VS2</td>\n",
" <td>62.4</td>\n",
" <td>58.0</td>\n",
" <td>334</td>\n",
" <td>4.20</td>\n",
" <td>4.23</td>\n",
" <td>2.63</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.31</td>\n",
" <td>Good</td>\n",
" <td>J</td>\n",
" <td>SI2</td>\n",
" <td>63.3</td>\n",
" <td>58.0</td>\n",
" <td>335</td>\n",
" <td>4.34</td>\n",
" <td>4.35</td>\n",
" <td>2.75</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53939</th>\n",
" <td>0.86</td>\n",
" <td>Premium</td>\n",
" <td>H</td>\n",
" <td>SI2</td>\n",
" <td>61.0</td>\n",
" <td>58.0</td>\n",
" <td>2757</td>\n",
" <td>6.15</td>\n",
" <td>6.12</td>\n",
" <td>3.74</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53940</th>\n",
" <td>0.75</td>\n",
" <td>Ideal</td>\n",
" <td>D</td>\n",
" <td>SI2</td>\n",
" <td>62.2</td>\n",
" <td>55.0</td>\n",
" <td>2757</td>\n",
" <td>5.83</td>\n",
" <td>5.87</td>\n",
" <td>3.64</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53941</th>\n",
" <td>0.71</td>\n",
" <td>Premium</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>60.5</td>\n",
" <td>55.0</td>\n",
" <td>2756</td>\n",
" <td>5.79</td>\n",
" <td>5.74</td>\n",
" <td>3.49</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53942</th>\n",
" <td>0.71</td>\n",
" <td>Premium</td>\n",
" <td>F</td>\n",
" <td>SI1</td>\n",
" <td>59.8</td>\n",
" <td>62.0</td>\n",
" <td>2756</td>\n",
" <td>5.74</td>\n",
" <td>5.73</td>\n",
" <td>3.43</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53943</th>\n",
" <td>0.70</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>60.5</td>\n",
" <td>59.0</td>\n",
" <td>2757</td>\n",
" <td>5.71</td>\n",
" <td>5.76</td>\n",
" <td>3.47</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>53943 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z \\\n",
"id \n",
"1 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 \n",
"2 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 \n",
"3 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 \n",
"4 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 \n",
"5 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 \n",
"... ... ... ... ... ... ... ... ... ... ... \n",
"53939 0.86 Premium H SI2 61.0 58.0 2757 6.15 6.12 3.74 \n",
"53940 0.75 Ideal D SI2 62.2 55.0 2757 5.83 5.87 3.64 \n",
"53941 0.71 Premium E SI1 60.5 55.0 2756 5.79 5.74 3.49 \n",
"53942 0.71 Premium F SI1 59.8 62.0 2756 5.74 5.73 3.43 \n",
"53943 0.70 Very Good E VS2 60.5 59.0 2757 5.71 5.76 3.47 \n",
"\n",
" above_average_carat \n",
"id \n",
"1 0 \n",
"2 0 \n",
"3 0 \n",
"4 0 \n",
"5 0 \n",
"... ... \n",
"53939 1 \n",
"53940 0 \n",
"53941 0 \n",
"53942 0 \n",
"53943 0 \n",
"\n",
"[53943 rows x 11 columns]"
]
},
"execution_count": 190,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"df = pd.read_csv(\"data/Diamonds.csv\", index_col=\"id\")\n",
"\n",
"random_state = 42\n",
"\n",
"average_carat = df['carat'].mean()\n",
"\n",
"print(f\"Среднее значение поля 'карат': {average_carat}\")\n",
"\n",
"# Binary flag: 1 if the diamond weighs more than the mean carat, otherwise 0\n",
"df['above_average_carat'] = (df['carat'] > average_carat).astype(int)\n",
"\n",
"df"
]
},
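{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Splitting the dataset into training and test sets (80/20) for the classification task\n",
"\n",
"Target feature: `above_average_carat` (1 if the diamond is heavier than the mean carat, otherwise 0); both sets are stratified by it"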
]
},
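{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split relies on the `stratify` argument of `train_test_split`. A minimal sketch on a hypothetical toy frame (`toy`: 80 rows of class 0 and 20 of class 1, not the diamonds data) shows that an 80/20 stratified split keeps the 20 % share of the positive class in both parts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hypothetical toy data: 80 negatives and 20 positives (20 % positive class)\n",
"toy = pd.DataFrame({'feature': range(100), 'label': [0] * 80 + [1] * 20})\n",
"\n",
"# stratify= makes both parts keep the class proportions of 'label'\n",
"toy_train, toy_test = train_test_split(\n",
"    toy, test_size=0.2, stratify=toy['label'], random_state=42\n",
")\n",
"\n",
"print(len(toy_train), len(toy_test))  # 80 20\n",
"print(toy_train['label'].mean(), toy_test['label'].mean())  # 0.2 0.2"
]
},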
{
"cell_type": "code",
"execution_count": 191,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>38836</th>\n",
" <td>0.40</td>\n",
" <td>Very Good</td>\n",
" <td>F</td>\n",
" <td>VVS2</td>\n",
" <td>62.0</td>\n",
" <td>56.0</td>\n",
" <td>1049</td>\n",
" <td>4.71</td>\n",
" <td>4.74</td>\n",
" <td>2.93</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30260</th>\n",
" <td>0.40</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>63.0</td>\n",
" <td>57.0</td>\n",
" <td>725</td>\n",
" <td>4.68</td>\n",
" <td>4.71</td>\n",
" <td>2.96</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33169</th>\n",
" <td>0.36</td>\n",
" <td>Ideal</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>61.8</td>\n",
" <td>56.0</td>\n",
" <td>817</td>\n",
" <td>4.55</td>\n",
" <td>4.58</td>\n",
" <td>2.82</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1029</th>\n",
" <td>0.70</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>58.4</td>\n",
" <td>59.0</td>\n",
" <td>2904</td>\n",
" <td>5.83</td>\n",
" <td>5.91</td>\n",
" <td>3.43</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53809</th>\n",
" <td>0.81</td>\n",
" <td>Very Good</td>\n",
" <td>G</td>\n",
" <td>SI1</td>\n",
" <td>60.7</td>\n",
" <td>56.0</td>\n",
" <td>2733</td>\n",
" <td>6.06</td>\n",
" <td>6.09</td>\n",
" <td>3.69</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2937</th>\n",
" <td>0.77</td>\n",
" <td>Good</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>63.4</td>\n",
" <td>57.0</td>\n",
" <td>3291</td>\n",
" <td>5.80</td>\n",
" <td>5.84</td>\n",
" <td>3.69</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7514</th>\n",
" <td>0.90</td>\n",
" <td>Good</td>\n",
" <td>F</td>\n",
" <td>SI1</td>\n",
" <td>61.8</td>\n",
" <td>63.0</td>\n",
" <td>4241</td>\n",
" <td>6.21</td>\n",
" <td>6.18</td>\n",
" <td>3.83</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48344</th>\n",
" <td>0.56</td>\n",
" <td>Ideal</td>\n",
" <td>H</td>\n",
" <td>VVS1</td>\n",
" <td>62.1</td>\n",
" <td>53.8</td>\n",
" <td>1961</td>\n",
" <td>5.27</td>\n",
" <td>5.33</td>\n",
" <td>3.29</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3212</th>\n",
" <td>0.70</td>\n",
" <td>Premium</td>\n",
" <td>F</td>\n",
" <td>VVS1</td>\n",
" <td>61.8</td>\n",
" <td>60.0</td>\n",
" <td>3348</td>\n",
" <td>5.67</td>\n",
" <td>5.63</td>\n",
" <td>3.49</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35654</th>\n",
" <td>0.31</td>\n",
" <td>Very Good</td>\n",
" <td>G</td>\n",
" <td>VVS2</td>\n",
" <td>63.1</td>\n",
" <td>57.0</td>\n",
" <td>907</td>\n",
" <td>4.32</td>\n",
" <td>4.30</td>\n",
" <td>2.72</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>43154 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z \\\n",
"id \n",
"38836 0.40 Very Good F VVS2 62.0 56.0 1049 4.71 4.74 2.93 \n",
"30260 0.40 Very Good E SI1 63.0 57.0 725 4.68 4.71 2.96 \n",
"33169 0.36 Ideal E VS1 61.8 56.0 817 4.55 4.58 2.82 \n",
"1029 0.70 Very Good E VS1 58.4 59.0 2904 5.83 5.91 3.43 \n",
"53809 0.81 Very Good G SI1 60.7 56.0 2733 6.06 6.09 3.69 \n",
"... ... ... ... ... ... ... ... ... ... ... \n",
"2937 0.77 Good E VS2 63.4 57.0 3291 5.80 5.84 3.69 \n",
"7514 0.90 Good F SI1 61.8 63.0 4241 6.21 6.18 3.83 \n",
"48344 0.56 Ideal H VVS1 62.1 53.8 1961 5.27 5.33 3.29 \n",
"3212 0.70 Premium F VVS1 61.8 60.0 3348 5.67 5.63 3.49 \n",
"35654 0.31 Very Good G VVS2 63.1 57.0 907 4.32 4.30 2.72 \n",
"\n",
" above_average_carat \n",
"id \n",
"38836 0 \n",
"30260 0 \n",
"33169 0 \n",
"1029 0 \n",
"53809 1 \n",
"... ... \n",
"2937 0 \n",
"7514 1 \n",
"48344 0 \n",
"3212 0 \n",
"35654 0 \n",
"\n",
"[43154 rows x 11 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>38836</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30260</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33169</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1029</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53809</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2937</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7514</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48344</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3212</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35654</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>43154 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_average_carat\n",
"id \n",
"38836 0\n",
"30260 0\n",
"33169 0\n",
"1029 0\n",
"53809 1\n",
"... ...\n",
"2937 0\n",
"7514 1\n",
"48344 0\n",
"3212 0\n",
"35654 0\n",
"\n",
"[43154 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>32452</th>\n",
" <td>0.39</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>60.9</td>\n",
" <td>58.0</td>\n",
" <td>793</td>\n",
" <td>4.72</td>\n",
" <td>4.77</td>\n",
" <td>2.89</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2432</th>\n",
" <td>0.72</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>63.3</td>\n",
" <td>56.0</td>\n",
" <td>3183</td>\n",
" <td>5.67</td>\n",
" <td>5.71</td>\n",
" <td>3.60</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16456</th>\n",
" <td>1.21</td>\n",
" <td>Ideal</td>\n",
" <td>H</td>\n",
" <td>SI1</td>\n",
" <td>62.1</td>\n",
" <td>59.0</td>\n",
" <td>6573</td>\n",
" <td>6.81</td>\n",
" <td>6.75</td>\n",
" <td>4.21</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46045</th>\n",
" <td>0.56</td>\n",
" <td>Ideal</td>\n",
" <td>D</td>\n",
" <td>SI1</td>\n",
" <td>62.5</td>\n",
" <td>56.0</td>\n",
" <td>1729</td>\n",
" <td>5.28</td>\n",
" <td>5.24</td>\n",
" <td>3.29</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11115</th>\n",
" <td>1.00</td>\n",
" <td>Good</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>62.4</td>\n",
" <td>59.0</td>\n",
" <td>4936</td>\n",
" <td>6.35</td>\n",
" <td>6.40</td>\n",
" <td>3.98</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40250</th>\n",
" <td>0.50</td>\n",
" <td>Premium</td>\n",
" <td>F</td>\n",
" <td>SI1</td>\n",
" <td>59.6</td>\n",
" <td>61.0</td>\n",
" <td>1125</td>\n",
" <td>5.15</td>\n",
" <td>5.12</td>\n",
" <td>3.06</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3308</th>\n",
" <td>0.73</td>\n",
" <td>Ideal</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>62.3</td>\n",
" <td>56.0</td>\n",
" <td>3370</td>\n",
" <td>5.75</td>\n",
" <td>5.80</td>\n",
" <td>3.60</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7894</th>\n",
" <td>1.12</td>\n",
" <td>Very Good</td>\n",
" <td>I</td>\n",
" <td>SI1</td>\n",
" <td>60.6</td>\n",
" <td>60.0</td>\n",
" <td>4312</td>\n",
" <td>6.73</td>\n",
" <td>6.77</td>\n",
" <td>4.09</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21368</th>\n",
" <td>0.36</td>\n",
" <td>Ideal</td>\n",
" <td>D</td>\n",
" <td>SI1</td>\n",
" <td>62.2</td>\n",
" <td>53.0</td>\n",
" <td>626</td>\n",
" <td>4.57</td>\n",
" <td>4.59</td>\n",
" <td>2.85</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46144</th>\n",
" <td>0.50</td>\n",
" <td>Premium</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>61.3</td>\n",
" <td>59.0</td>\n",
" <td>1746</td>\n",
" <td>5.10</td>\n",
" <td>5.05</td>\n",
" <td>3.11</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>10789 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z \\\n",
"id \n",
"32452 0.39 Very Good E VS2 60.9 58.0 793 4.72 4.77 2.89 \n",
"2432 0.72 Very Good E SI1 63.3 56.0 3183 5.67 5.71 3.60 \n",
"16456 1.21 Ideal H SI1 62.1 59.0 6573 6.81 6.75 4.21 \n",
"46045 0.56 Ideal D SI1 62.5 56.0 1729 5.28 5.24 3.29 \n",
"11115 1.00 Good E SI1 62.4 59.0 4936 6.35 6.40 3.98 \n",
"... ... ... ... ... ... ... ... ... ... ... \n",
"40250 0.50 Premium F SI1 59.6 61.0 1125 5.15 5.12 3.06 \n",
"3308 0.73 Ideal E VS1 62.3 56.0 3370 5.75 5.80 3.60 \n",
"7894 1.12 Very Good I SI1 60.6 60.0 4312 6.73 6.77 4.09 \n",
"21368 0.36 Ideal D SI1 62.2 53.0 626 4.57 4.59 2.85 \n",
"46144 0.50 Premium E VS2 61.3 59.0 1746 5.10 5.05 3.11 \n",
"\n",
" above_average_carat \n",
"id \n",
"32452 0 \n",
"2432 0 \n",
"16456 1 \n",
"46045 0 \n",
"11115 1 \n",
"... ... \n",
"40250 0 \n",
"3308 0 \n",
"7894 1 \n",
"21368 0 \n",
"46144 0 \n",
"\n",
"[10789 rows x 11 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>32452</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2432</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16456</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46045</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11115</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40250</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3308</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7894</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21368</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46144</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>10789 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_average_carat\n",
"id \n",
"32452 0\n",
"2432 0\n",
"16456 1\n",
"46045 0\n",
"11115 1\n",
"... ...\n",
"40250 0\n",
"3308 0\n",
"7894 1\n",
"21368 0\n",
"46144 0\n",
"\n",
"[10789 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from math import isclose\n",
"from typing import Tuple\n",
"\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input,\n",
"    stratify_colname=\"y\",\n",
"    frac_train=0.6,\n",
"    frac_val=0.15,\n",
"    frac_test=0.25,\n",
"    random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"\n",
"    Splits a Pandas dataframe into three subsets (train, val, and test)\n",
"    following fractional ratios provided by the user, where each subset is\n",
"    stratified by the values in a specific column (that is, each subset has\n",
"    the same relative frequency of the values in the column). It performs this\n",
"    splitting by running train_test_split() twice.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df_input : Pandas dataframe\n",
"        Input dataframe to be split.\n",
"    stratify_colname : str\n",
"        The name of the column that will be used for stratification. Usually\n",
"        this column would be for the label.\n",
"    frac_train : float\n",
"    frac_val : float\n",
"    frac_test : float\n",
"        The ratios with which the dataframe will be split into train, val, and\n",
"        test data. The values should be expressed as float fractions and should\n",
"        sum to 1.0. If frac_val <= 0, only a train/test split is made and empty\n",
"        dataframes are returned for the validation parts.\n",
"    random_state : int, None, or RandomStateInstance\n",
"        Value to be passed to train_test_split().\n",
"\n",
"    Returns\n",
"    -------\n",
"    df_train, df_val, df_test, y_train, y_val, y_test :\n",
"        Dataframes containing the three feature splits and the matching\n",
"        splits of the stratification column.\n",
"    \"\"\"\n",
"    # Compare with a tolerance: floating-point sums such as 0.7 + 0.2 + 0.1\n",
"    # need not equal 1.0 exactly\n",
"    if not isclose(frac_train + frac_val + frac_test, 1.0):\n",
"        raise ValueError(\n",
"            \"fractions %f, %f, %f do not add up to 1.0\"\n",
"            % (frac_train, frac_val, frac_test)\n",
"        )\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"    X = df_input  # Contains all columns.\n",
"    y = df_input[[stratify_colname]]  # Dataframe of just the column on which to stratify.\n",
"    # Split original dataframe into train and temp dataframes.\n",
"    df_train, df_temp, y_train, y_temp = train_test_split(\n",
"        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
"    )\n",
"    if frac_val <= 0:\n",
"        assert len(df_input) == len(df_train) + len(df_temp)\n",
"        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
"    # Split the temp dataframe into val and test dataframes.\n",
"    relative_frac_test = frac_test / (frac_val + frac_test)\n",
"    df_val, df_test, y_val, y_test = train_test_split(\n",
"        df_temp,\n",
"        y_temp,\n",
"        stratify=y_temp,\n",
"        test_size=relative_frac_test,\n",
"        random_state=random_state,\n",
"    )\n",
"    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"    return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
"    df, stratify_colname=\"above_average_carat\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Building the classification pipeline\n",
"\n",
"preprocessing_num -- pipeline for numeric features: median imputation of missing values, then standardization\n",
"\n",
"preprocessing_cat -- pipeline for categorical features: filling missing values with `unknown`, then one-hot encoding (first category dropped)\n",
"\n",
"features_preprocessing -- transformer that applies both pipelines to their columns and passes the remaining column (`price`) through unchanged\n",
"\n",
"features_engineering -- transformer that constructs new features (`Length_to_Width_Ratio = x / y`)\n",
"\n",
"drop_columns -- transformer that removes the columns listed in `columns_to_drop` (currently none)\n",
"\n",
"pipeline_end -- main pipeline chaining data preprocessing, feature construction and column dropping"
]
},
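{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the two sub-pipelines concrete, a minimal sketch on a hypothetical four-row frame (`toy`) with one missing value per column: the numeric branch fills the gap with the median and standardizes, the categorical branch fills it with `unknown` and one-hot encodes with the first category dropped. It mirrors the settings of `preprocessing_num` and `preprocessing_cat` defined below; the names `num`, `cat` and `prep` are illustrative and are not the notebook's own pipeline objects."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"# Hypothetical toy data with one missing value in each column\n",
"toy = pd.DataFrame({'x': [4.0, 5.0, np.nan, 6.0], 'cut': ['Ideal', 'Good', np.nan, 'Ideal']})\n",
"\n",
"# Numeric branch: median imputation (NaN -> 5.0), then standardization\n",
"num = Pipeline([('imputer', SimpleImputer(strategy='median')), ('scaler', StandardScaler())])\n",
"# Categorical branch: NaN -> 'unknown', then one-hot encoding without the first category ('Good')\n",
"cat = Pipeline([\n",
"    ('imputer', SimpleImputer(strategy='constant', fill_value='unknown')),\n",
"    ('encoder', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop='first')),\n",
"])\n",
"\n",
"prep = ColumnTransformer(\n",
"    [('num', num, ['x']), ('cat', cat, ['cut'])],\n",
"    verbose_feature_names_out=False,\n",
")\n",
"prep.set_output(transform='pandas')\n",
"\n",
"out = prep.fit_transform(toy)\n",
"print(out)  # columns: x, cut_Ideal, cut_unknown"
]
},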
{
"cell_type": "code",
"execution_count": 192,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"class DiamondFeatures(BaseEstimator, TransformerMixin):\n",
"    \"\"\"Appends Length_to_Width_Ratio = x / y to the incoming columns.\"\"\"\n",
"\n",
"    def fit(self, X, y=None):\n",
"        return self\n",
"\n",
"    def transform(self, X, y=None):\n",
"        X = X.copy()  # avoid modifying the caller's DataFrame in place\n",
"        X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n",
"        return X\n",
"\n",
"    def get_feature_names_out(self, features_in):\n",
"        return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
"\n",
"\n",
"columns_to_drop = []\n",
"# NB: above_average_carat is also the classification target (y_train / y_test),\n",
"# so listing it here passes the label itself to the model as a feature\n",
"num_columns = [\"carat\", \"depth\", \"table\", \"x\", \"y\", \"z\", \"above_average_carat\"]\n",
"cat_columns = [\"cut\", \"color\", \"clarity\"]\n",
"\n",
"# Numeric columns: median imputation, then standardization\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
"    [\n",
"        (\"imputer\", num_imputer),\n",
"        (\"scaler\", num_scaler),\n",
"    ]\n",
")\n",
"\n",
"# Categorical columns: constant \"unknown\" imputation, then one-hot encoding\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
"    [\n",
"        (\"imputer\", cat_imputer),\n",
"        (\"encoder\", cat_encoder),\n",
"    ]\n",
")\n",
"\n",
"# Columns not listed above (here: price) are passed through unchanged\n",
"features_preprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"preprocessing_num\", preprocessing_num, num_columns),\n",
"        (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"features_engineering = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"add_features\", DiamondFeatures(), [\"x\", \"y\"]),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"drop_columns\", \"drop\", columns_to_drop),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
"    [\n",
"        (\"features_preprocessing\", features_preprocessing),\n",
"        (\"features_engineering\", features_engineering),\n",
"        (\"drop_columns\", drop_columns),\n",
"    ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Demonstrating the preprocessing pipeline on the classification data"
]
},
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>Length_to_Width_Ratio</th>\n",
" <th>carat</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" <th>cut_Good</th>\n",
" <th>cut_Ideal</th>\n",
" <th>...</th>\n",
" <th>color_I</th>\n",
" <th>color_J</th>\n",
" <th>clarity_IF</th>\n",
" <th>clarity_SI1</th>\n",
" <th>clarity_SI2</th>\n",
" <th>clarity_VS1</th>\n",
" <th>clarity_VS2</th>\n",
" <th>clarity_VVS1</th>\n",
" <th>clarity_VVS2</th>\n",
" <th>price</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>38836</th>\n",
" <td>-0.907744</td>\n",
" <td>-0.863476</td>\n",
" <td>1.051267</td>\n",
" <td>-0.837490</td>\n",
" <td>0.176170</td>\n",
" <td>-0.648004</td>\n",
" <td>-0.857040</td>\n",
" <td>-0.856046</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1049</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30260</th>\n",
" <td>-0.934483</td>\n",
" <td>-0.889579</td>\n",
" <td>1.050478</td>\n",
" <td>-0.837490</td>\n",
" <td>0.876071</td>\n",
" <td>-0.201125</td>\n",
" <td>-0.814688</td>\n",
" <td>-0.856046</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>725</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33169</th>\n",
" <td>-1.050350</td>\n",
" <td>-1.002691</td>\n",
" <td>1.047532</td>\n",
" <td>-0.921885</td>\n",
" <td>0.036190</td>\n",
" <td>-0.648004</td>\n",
" <td>-1.012333</td>\n",
" <td>-0.856046</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>817</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1029</th>\n",
" <td>0.090496</td>\n",
" <td>0.154530</td>\n",
" <td>0.585622</td>\n",
" <td>-0.204531</td>\n",
" <td>-2.343471</td>\n",
" <td>0.692631</td>\n",
" <td>-0.151165</td>\n",
" <td>-0.856046</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2904</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53809</th>\n",
" <td>0.295492</td>\n",
" <td>0.311147</td>\n",
" <td>0.949688</td>\n",
" <td>0.027554</td>\n",
" <td>-0.733700</td>\n",
" <td>-0.648004</td>\n",
" <td>0.215890</td>\n",
" <td>1.168162</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2733</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2937</th>\n",
" <td>0.063758</td>\n",
" <td>0.093624</td>\n",
" <td>0.680999</td>\n",
" <td>-0.056841</td>\n",
" <td>1.156031</td>\n",
" <td>-0.201125</td>\n",
" <td>0.215890</td>\n",
" <td>-0.856046</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3291</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7514</th>\n",
" <td>0.429185</td>\n",
" <td>0.389455</td>\n",
" <td>1.102015</td>\n",
" <td>0.217442</td>\n",
" <td>0.036190</td>\n",
" <td>2.480145</td>\n",
" <td>0.413535</td>\n",
" <td>1.168162</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>4241</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48344</th>\n",
" <td>-0.408624</td>\n",
" <td>-0.350123</td>\n",
" <td>1.167088</td>\n",
" <td>-0.499912</td>\n",
" <td>0.246160</td>\n",
" <td>-1.631136</td>\n",
" <td>-0.348810</td>\n",
" <td>-0.856046</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1961</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3212</th>\n",
" <td>-0.052109</td>\n",
" <td>-0.089095</td>\n",
" <td>0.584874</td>\n",
" <td>-0.204531</td>\n",
" <td>0.036190</td>\n",
" <td>1.139510</td>\n",
" <td>-0.066460</td>\n",
" <td>-0.856046</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>3348</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35654</th>\n",
" <td>-1.255346</td>\n",
" <td>-1.246316</td>\n",
" <td>1.007245</td>\n",
" <td>-1.027378</td>\n",
" <td>0.946061</td>\n",
" <td>-0.201125</td>\n",
" <td>-1.153508</td>\n",
" <td>-0.856046</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>907</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>43154 rows × 26 columns</p>\n",
"</div>"
],
"text/plain": [
" x y Length_to_Width_Ratio carat depth \\\n",
"id \n",
"38836 -0.907744 -0.863476 1.051267 -0.837490 0.176170 \n",
"30260 -0.934483 -0.889579 1.050478 -0.837490 0.876071 \n",
"33169 -1.050350 -1.002691 1.047532 -0.921885 0.036190 \n",
"1029 0.090496 0.154530 0.585622 -0.204531 -2.343471 \n",
"53809 0.295492 0.311147 0.949688 0.027554 -0.733700 \n",
"... ... ... ... ... ... \n",
"2937 0.063758 0.093624 0.680999 -0.056841 1.156031 \n",
"7514 0.429185 0.389455 1.102015 0.217442 0.036190 \n",
"48344 -0.408624 -0.350123 1.167088 -0.499912 0.246160 \n",
"3212 -0.052109 -0.089095 0.584874 -0.204531 0.036190 \n",
"35654 -1.255346 -1.246316 1.007245 -1.027378 0.946061 \n",
"\n",
" table z above_average_carat cut_Good cut_Ideal ... \\\n",
"id ... \n",
"38836 -0.648004 -0.857040 -0.856046 0.0 0.0 ... \n",
"30260 -0.201125 -0.814688 -0.856046 0.0 0.0 ... \n",
"33169 -0.648004 -1.012333 -0.856046 0.0 1.0 ... \n",
"1029 0.692631 -0.151165 -0.856046 0.0 0.0 ... \n",
"53809 -0.648004 0.215890 1.168162 0.0 0.0 ... \n",
"... ... ... ... ... ... ... \n",
"2937 -0.201125 0.215890 -0.856046 1.0 0.0 ... \n",
"7514 2.480145 0.413535 1.168162 1.0 0.0 ... \n",
"48344 -1.631136 -0.348810 -0.856046 0.0 1.0 ... \n",
"3212 1.139510 -0.066460 -0.856046 0.0 0.0 ... \n",
"35654 -0.201125 -1.153508 -0.856046 0.0 0.0 ... \n",
"\n",
" color_I color_J clarity_IF clarity_SI1 clarity_SI2 clarity_VS1 \\\n",
"id \n",
"38836 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"30260 0.0 0.0 0.0 1.0 0.0 0.0 \n",
"33169 0.0 0.0 0.0 0.0 0.0 1.0 \n",
"1029 0.0 0.0 0.0 0.0 0.0 1.0 \n",
"53809 0.0 0.0 0.0 1.0 0.0 0.0 \n",
"... ... ... ... ... ... ... \n",
"2937 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"7514 0.0 0.0 0.0 1.0 0.0 0.0 \n",
"48344 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"3212 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"35654 0.0 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" clarity_VS2 clarity_VVS1 clarity_VVS2 price \n",
"id \n",
"38836 0.0 0.0 1.0 1049 \n",
"30260 0.0 0.0 0.0 725 \n",
"33169 0.0 0.0 0.0 817 \n",
"1029 0.0 0.0 0.0 2904 \n",
"53809 0.0 0.0 0.0 2733 \n",
"... ... ... ... ... \n",
"2937 1.0 0.0 0.0 3291 \n",
"7514 0.0 0.0 0.0 4241 \n",
"48344 0.0 1.0 0.0 1961 \n",
"3212 0.0 1.0 0.0 3348 \n",
"35654 0.0 0.0 1.0 907 \n",
"\n",
"[43154 rows x 26 columns]"
]
},
"execution_count": 193,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Building the set of classification models\n",
"\n",
"logistic -- logistic regression (default settings)\n",
"\n",
"ridge -- L2-regularized (ridge) logistic regression with balanced class weights\n",
"\n",
"decision_tree -- decision tree\n",
"\n",
"knn -- k-nearest neighbors\n",
"\n",
"naive_bayes -- Gaussian naive Bayes classifier\n",
"\n",
"gradient_boosting -- gradient boosting (an ensemble of decision trees)\n",
"\n",
"random_forest -- random forest (an ensemble of decision trees)\n",
"\n",
"mlp -- multilayer perceptron (a neural network)"
]
},
{
"cell_type": "code",
"execution_count": 194,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
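"    # RidgeClassifierCV has no predict_proba(), which the evaluation loop needs,\n",
"    # so an L2-regularized logistic regression with balanced class weights is used instead\n",
"    # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",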
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
"        \"model\": ensemble.GradientBoostingClassifier(n_estimators=210, random_state=random_state)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Training the models on the training set and evaluating them on the test set"
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
"    # Chain the shared preprocessing pipeline (pipeline_end) with the current model\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
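"\n",
"    # Class 1 = above-average carat; test labels come from class-1 probabilities thresholded at 0.5\n",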
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
"    # Store train/test metrics for the summary tables below\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Summary table of quality metrics for the classification models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Confusion matrices"
]
},
{
"cell_type": "code",
"execution_count": 197,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0cAAAQ9CAYAAACSpDaqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXxM5/4H8M+ZSWayTjayERGCSu10UWstCdXi0rpVSqylsUUt1dpCiepVS2sptbb81FXtVbsi1FLXXkvEFhJLYokkEklmO78/cjNMk0wyTMxkzud9X+d15TzPnHlmKufje56zCKIoiiAiIiIiIpI4mbUHQEREREREZAtYHBEREREREYHFEREREREREQAWR0RERERERABYHBEREREREQFgcURERERERASAxREREREREREAFkdEREREREQAWBwREREREREBYHFEz2jVqlUQBAHXr18vk+1fv34dgiBg1apVFtleXFwcBEFAXFycRbZHRERkL6ZOnQpBEErVVxAETJ06tWwHRGRFLI7IrixatMhiBRURERERSYuDtQdAVJTg4GDk5OTA0dHRrNctWrQIFSpUQGRkpNH6li1bIicnBwqFwoKjJCIiKv8mTpyITz/91NrDILIJLI7IJgmCACcnJ4ttTyaTWXR7RERE9iA7Oxuurq5wcOA/CYkAnlZHFrRo0SK8/PLLUCqVCAwMRFRUFNLT0wv1W7hwIapVqwZnZ2e8+uqr+OOPP9C6dWu0bt3a0Keoa45SUlLQr18/VK5cGUqlEgEBAejSpYvhuqeqVavi/Pnz2L9/PwRBgCAIhm0Wd83R0aNH8dZbb8HLywuurq6oV68e5s+fb9kvhoiIyAYUXFt04cIFfPDBB/Dy8kLz5s2LvOYoLy8P0dHRqFixItzd3dG5c2fcvHmzyO3GxcWhSZMmcHJyQvXq1fHdd98Vex3Tjz/+iMaNG8PZ2Rne3t54//33kZycXCafl+hZ8DABWcTUqVMRExODdu3aYejQoUhISMDixYtx7NgxHDp0yHB63OLFizFs2DC0aNEC0dHRuH79Orp27QovLy9UrlzZ5Ht0794d58+fx/Dhw1G1alXcvXsXu3fvRlJSEqpWrYp58+Zh+PDhcHNzw+effw4A8PPzK3Z7u3fvxttvv42AgACMHDkS/v7+iI+Px5YtWzBy5EjLfTlEREQ25L333kONGjUwc+ZMiKKIu3fvFuozcOBA/Pjjj/jggw/wxhtvYO/evejUqVOhfqdOnUKHDh0QEBCAmJgY6HQ6TJs2DRUrVizUd8aMGZg0aRJ69OiBgQMH4t69e/jmm2/QsmVLnDp1Cp6enmXxcYnMIxI9g5UrV4oAxMTERPHu3buiQqEQw8PDRZ1OZ+jz7bffigDEFStWiKIoinl5eaKPj4/4yiuviBqNxtBv1apVIgCxVatWhnWJiYkiAHHlypWiKIriw4cPRQDiV199ZXJcL7/8stF2Cuzbt08EIO7bt08URVHUarViSEiIGBwcLD58+NCor16vL/0XQUREVE5MmTJFBCD27NmzyPUFTp8+LQIQP/74Y6N+H3zwgQhAnDJlimHdO++8I7q4uIi3bt0yrLt8+bLo4OBgtM3r16+LcrlcnDFjhtE2z549Kzo4OBRaT2QtPK2Ontvvv/8OtVqNUaNGQSZ78ldq0KBBUKlU2Lp1KwDg+PHjePDgAQYNGmR0bnOvXr3g5eVl8j2cnZ2hUCgQFxeHhw8fPveYT506hcTERIwaNarQkarS3s6UiIioPBoyZIjJ9m3btgEARowYYbR+1KhRRj/rdDr8/vvv6Nq1KwIDAw3rQ0ND0bFjR6O+mzZtgl6vR48ePXD//n3D4u/vjxo1amDfvn3P8YmILIen1dFzu3HjBgCgVq1aRusVCgWqVatmaC/4/9DQUKN+Dg4OqFq1qsn3UCqV+PLLL/HJJ5/Az88Pr7/+Ot5++2306dMH/v7+Zo/56tWrAIA6deqY/V
oiIqLyLCQkxGT7jRs3IJPJUL16daP1f8/5u3fvIicnp1CuA4Wz/vLlyxBFETVq1CjyPc29Oy1RWWFxROXGqFGj8M477+DXX3/Fzp07MWnSJMTGxmLv3r1o2LChtYdHRERULjg7O7/w99Tr9RAEAdu3b4dcLi/U7ubm9sLHRFQUnlZHzy04OBgAkJCQYLRerVYjMTHR0F7w/1euXDHqp9VqDXecK0n16tXxySefYNeuXTh37hzUajXmzJljaC/tKXEFR8POnTtXqv5ERERSERwcDL1ebzjLosDfc97X1xdOTk6Fch0onPXVq1eHKIoICQlBu3btCi2vv/665T8I0TNgcUTPrV27dlAoFFiwYAFEUTSsX758OTIyMgx3t2nSpAl8fHywbNkyaLVaQ7+1a9eWeB3R48ePkZuba7SuevXqcHd3R15enmGdq6trkbcP/7tGjRohJCQE8+bNK9T/6c9AREQkNQXXCy1YsMBo/bx584x+lsvlaNeuHX799Vfcvn3bsP7KlSvYvn27Ud9u3bpBLpcjJiamUM6KoogHDx5Y8BMQPTueVkfPrWLFipgwYQJiYmLQoUMHdO7cGQkJCVi0aBFeeeUV9O7dG0D+NUhTp07F8OHD0aZNG/To0QPXr1/HqlWrUL16dZOzPpcuXULbtm3Ro0cPhIWFwcHBAb/88gtSU1Px/vvvG/o1btwYixcvxhdffIHQ0FD4+vqiTZs2hbYnk8mwePFivPPOO2jQoAH69euHgIAAXLx4EefPn8fOnTst/0URERGVAw0aNEDPnj2xaNEiZGRk4I033sCePXuKnCGaOnUqdu3ahWbNmmHo0KHQ6XT49ttvUadOHZw+fdrQr3r16vjiiy8wYcIEw2M83N3dkZiYiF9++QWDBw/GmDFjXuCnJCoaiyOyiKlTp6JixYr49ttvER0dDW9vbwwePBgzZ840ushy2LBhEEURc+bMwZgxY1C/fn1s3rwZI0aMgJOTU7HbDwoKQs+ePbFnzx788MMPcHBwwEsvvYQNGzage/fuhn6TJ0/GjRs3MHv2bDx69AitWrUqsjgCgIiICOzbtw8xMTGYM2cO9Ho9qlevjkGDBlnuiyEiIiqHVqxYgYoVK2Lt2rX49ddf0aZNG2zduhVBQUFG/Ro3bozt27djzJgxmDRpEoKCgjBt2jTEx8fj4sWLRn0//fRT1KxZE3PnzkVMTAyA/HwPDw9H586dX9hnIzJFEHkOEVmZXq9HxYoV0a1bNyxbtszawyEiIqLn1LVrV5w/fx6XL1+29lCIzMJrjuiFys3NLXSu8Zo1a5CWlobWrVtbZ1BERET0zHJycox+vnz5MrZt28Zcp3KJM0f0QsXFxSE6OhrvvfcefHx8cPLkSSxfvhy1a9fGiRMnoFAorD1EIiIiMkNAQAAiIyMNzzZcvHgx8vLycOrUqWKfa0Rkq3jNEb1QVatWRVBQEBYsWIC0tDR4e3ujT58+mDVrFgsjIiKicqhDhw74v//7P6SkpECpVKJp06aYOXMmCyMqlzhzREREREREBF5zREREREREBIDFEREREREREQBec1Qqer0et2/fhru7u8kHlRLZI1EU8ejRIwQGBkIms+zxlNzcXKjV6hL7KRQKk8/BIiLpYTaTlDGbyw6Lo1K4fft2oYeeEUlNcnIyKleubLHt5ebmIiTYDSl3dSX29ff3R2Jiol3uhIno2TCbiZjNZYHFUSm4u7sDAG6crAqVG89EtIZ/1Kxr7SFIlhYaHMQ2w++BpajVaqTc1eHK8SCo3Iv/vcp8pEdok2So1Wq72wET0bNjNlsfs9l6mM1lh8VRKRRM16vcZCb/olDZcRAcrT0E6frf/SzL6rQVN3cBbu7Fb1sPni5DRIUxm62P2WxFzOYyw70JEVmVRtSVuJjr1q1b6N27N3x8fODs7Iy6devi+PHjhnZRFDF58mQEBATA2dkZ7dq1w+XLl422kZaWhl69ekGlUsHT0xMDBgxAVlaWUZ+//v
oLLVq0gJOTE4KCgjB79uxn+xKIiIhsiJSzmcUREVmVHmKJizkePnyIZs2awdHREdu3b8eFCxcwZ84ceHl5GfrMnj0
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
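"# One confusion matrix per model, laid out in two columns (rounded up for an odd model count)\n",
"_, ax = plt.subplots((len(class_models) + 1) // 2, 2, figsize=(12, 10), sharex=False, sharey=False)\n",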
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Precision, recall, accuracy, F1 score"
]
},
{
"cell_type": "code",
"execution_count": 198,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_809a5_row0_col0, #T_809a5_row0_col1, #T_809a5_row0_col2, #T_809a5_row0_col3, #T_809a5_row1_col0, #T_809a5_row1_col1, #T_809a5_row1_col2, #T_809a5_row1_col3, #T_809a5_row2_col0, #T_809a5_row2_col1, #T_809a5_row2_col2, #T_809a5_row2_col3, #T_809a5_row3_col1, #T_809a5_row3_col2, #T_809a5_row3_col3, #T_809a5_row4_col0, #T_809a5_row4_col1, #T_809a5_row4_col2, #T_809a5_row4_col3, #T_809a5_row5_col0, #T_809a5_row5_col1, #T_809a5_row5_col2, #T_809a5_row5_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_809a5_row0_col4, #T_809a5_row0_col5, #T_809a5_row0_col6, #T_809a5_row0_col7, #T_809a5_row1_col4, #T_809a5_row1_col5, #T_809a5_row1_col6, #T_809a5_row1_col7, #T_809a5_row2_col4, #T_809a5_row2_col5, #T_809a5_row2_col6, #T_809a5_row2_col7, #T_809a5_row3_col4, #T_809a5_row3_col5, #T_809a5_row3_col6, #T_809a5_row3_col7, #T_809a5_row4_col4, #T_809a5_row4_col5, #T_809a5_row4_col6, #T_809a5_row4_col7, #T_809a5_row5_col4, #T_809a5_row5_col5, #T_809a5_row5_col6, #T_809a5_row5_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_809a5_row3_col0, #T_809a5_row6_col2 {\n",
" background-color: #a5db36;\n",
" color: #000000;\n",
"}\n",
"#T_809a5_row6_col0 {\n",
" background-color: #a0da39;\n",
" color: #000000;\n",
"}\n",
"#T_809a5_row6_col1, #T_809a5_row6_col3 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_809a5_row6_col4, #T_809a5_row6_col5, #T_809a5_row6_col6, #T_809a5_row6_col7 {\n",
" background-color: #d8576b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_809a5_row7_col0, #T_809a5_row7_col1, #T_809a5_row7_col2, #T_809a5_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_809a5_row7_col4, #T_809a5_row7_col5, #T_809a5_row7_col6, #T_809a5_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_809a5\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_809a5_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_809a5_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_809a5_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_809a5_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_809a5_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_809a5_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_809a5_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_809a5_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_809a5_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_809a5_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_809a5_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_809a5_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_809a5_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_809a5_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_809a5_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_809a5_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_809a5_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_809a5_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_809a5_row1_col0\" class=\"data row1 col0\" >0.999945</td>\n",
" <td id=\"T_809a5_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_809a5_row1_col2\" class=\"data row1 col2\" >0.999945</td>\n",
" <td id=\"T_809a5_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_809a5_row1_col4\" class=\"data row1 col4\" >0.999954</td>\n",
" <td id=\"T_809a5_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_809a5_row1_col6\" class=\"data row1 col6\" >0.999945</td>\n",
" <td id=\"T_809a5_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_809a5_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_809a5_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_809a5_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_809a5_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_809a5_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_809a5_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_809a5_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_809a5_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_809a5_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_809a5_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_809a5_row3_col0\" class=\"data row3 col0\" >0.999890</td>\n",
" <td id=\"T_809a5_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_809a5_row3_col2\" class=\"data row3 col2\" >0.999890</td>\n",
" <td id=\"T_809a5_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_809a5_row3_col4\" class=\"data row3 col4\" >0.999907</td>\n",
" <td id=\"T_809a5_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
" <td id=\"T_809a5_row3_col6\" class=\"data row3 col6\" >0.999890</td>\n",
" <td id=\"T_809a5_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_809a5_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_809a5_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_809a5_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_809a5_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_809a5_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_809a5_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" <td id=\"T_809a5_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
" <td id=\"T_809a5_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
" <td id=\"T_809a5_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_809a5_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_809a5_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_809a5_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_809a5_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_809a5_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_809a5_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" <td id=\"T_809a5_row5_col5\" class=\"data row5 col5\" >1.000000</td>\n",
" <td id=\"T_809a5_row5_col6\" class=\"data row5 col6\" >1.000000</td>\n",
" <td id=\"T_809a5_row5_col7\" class=\"data row5 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_809a5_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
" <td id=\"T_809a5_row6_col0\" class=\"data row6 col0\" >0.999507</td>\n",
" <td id=\"T_809a5_row6_col1\" class=\"data row6 col1\" >0.999562</td>\n",
" <td id=\"T_809a5_row6_col2\" class=\"data row6 col2\" >0.999836</td>\n",
" <td id=\"T_809a5_row6_col3\" class=\"data row6 col3\" >0.999562</td>\n",
" <td id=\"T_809a5_row6_col4\" class=\"data row6 col4\" >0.999722</td>\n",
" <td id=\"T_809a5_row6_col5\" class=\"data row6 col5\" >0.999629</td>\n",
" <td id=\"T_809a5_row6_col6\" class=\"data row6 col6\" >0.999671</td>\n",
" <td id=\"T_809a5_row6_col7\" class=\"data row6 col7\" >0.999562</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_809a5_level0_row7\" class=\"row_heading level0 row7\" >knn</th>\n",
" <td id=\"T_809a5_row7_col0\" class=\"data row7 col0\" >0.983970</td>\n",
" <td id=\"T_809a5_row7_col1\" class=\"data row7 col1\" >0.979300</td>\n",
" <td id=\"T_809a5_row7_col2\" class=\"data row7 col2\" >0.978740</td>\n",
" <td id=\"T_809a5_row7_col3\" class=\"data row7 col3\" >0.974578</td>\n",
" <td id=\"T_809a5_row7_col4\" class=\"data row7 col4\" >0.984266</td>\n",
" <td id=\"T_809a5_row7_col5\" class=\"data row7 col5\" >0.980536</td>\n",
" <td id=\"T_809a5_row7_col6\" class=\"data row7 col6\" >0.981348</td>\n",
" <td id=\"T_809a5_row7_col7\" class=\"data row7 col7\" >0.976933</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x21f4db75340>"
]
},
"execution_count": 198,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC AUC, Cohen's kappa, Matthews correlation coefficient (MCC)"
]
},
{
"cell_type": "code",
"execution_count": 199,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_b558d_row0_col0, #T_b558d_row0_col1, #T_b558d_row1_col0, #T_b558d_row1_col1, #T_b558d_row2_col0, #T_b558d_row2_col1, #T_b558d_row3_col0, #T_b558d_row3_col1, #T_b558d_row4_col0, #T_b558d_row4_col1, #T_b558d_row5_col0, #T_b558d_row5_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_b558d_row0_col2, #T_b558d_row0_col3, #T_b558d_row0_col4, #T_b558d_row1_col2, #T_b558d_row1_col3, #T_b558d_row1_col4, #T_b558d_row2_col2, #T_b558d_row2_col3, #T_b558d_row2_col4, #T_b558d_row3_col2, #T_b558d_row3_col3, #T_b558d_row3_col4, #T_b558d_row4_col2, #T_b558d_row4_col3, #T_b558d_row4_col4, #T_b558d_row5_col2, #T_b558d_row5_col3, #T_b558d_row5_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b558d_row6_col0, #T_b558d_row6_col1 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_b558d_row6_col2 {\n",
" background-color: #d45270;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b558d_row6_col3, #T_b558d_row6_col4 {\n",
" background-color: #d8576b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b558d_row7_col0, #T_b558d_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b558d_row7_col2, #T_b558d_row7_col3, #T_b558d_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_b558d\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_b558d_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_b558d_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_b558d_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_b558d_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_b558d_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_b558d_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_b558d_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_b558d_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_b558d_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_b558d_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_b558d_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b558d_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_b558d_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_b558d_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_b558d_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_b558d_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_b558d_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b558d_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_b558d_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_b558d_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_b558d_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_b558d_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_b558d_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b558d_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_b558d_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_b558d_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_b558d_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_b558d_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_b558d_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b558d_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_b558d_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_b558d_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_b558d_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_b558d_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_b558d_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b558d_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_b558d_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_b558d_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_b558d_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_b558d_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_b558d_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b558d_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
" <td id=\"T_b558d_row6_col0\" class=\"data row6 col0\" >0.999629</td>\n",
" <td id=\"T_b558d_row6_col1\" class=\"data row6 col1\" >0.999562</td>\n",
" <td id=\"T_b558d_row6_col2\" class=\"data row6 col2\" >0.999754</td>\n",
" <td id=\"T_b558d_row6_col3\" class=\"data row6 col3\" >0.999240</td>\n",
" <td id=\"T_b558d_row6_col4\" class=\"data row6 col4\" >0.999240</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b558d_level0_row7\" class=\"row_heading level0 row7\" >knn</th>\n",
" <td id=\"T_b558d_row7_col0\" class=\"data row7 col0\" >0.980536</td>\n",
" <td id=\"T_b558d_row7_col1\" class=\"data row7 col1\" >0.976933</td>\n",
" <td id=\"T_b558d_row7_col2\" class=\"data row7 col2\" >0.995960</td>\n",
" <td id=\"T_b558d_row7_col3\" class=\"data row7 col3\" >0.960098</td>\n",
" <td id=\"T_b558d_row7_col4\" class=\"data row7 col4\" >0.960107</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x21f49b924b0>"
]
},
"execution_count": 199,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 200,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Displaying misclassified test samples for evaluation"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>Predicted</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [carat, Predicted, cut, color, clarity, depth, table, price, x, y, z, above_average_carat]\n",
"Index: []"
]
},
"execution_count": 206,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"above_average_carat\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Example: predicting with the trained model (pipeline)"
]
},
{
"cell_type": "code",
"execution_count": 208,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>4500</th>\n",
" <td>0.9</td>\n",
" <td>Premium</td>\n",
" <td>H</td>\n",
" <td>SI1</td>\n",
" <td>61.9</td>\n",
" <td>58.0</td>\n",
" <td>3629</td>\n",
" <td>6.2</td>\n",
" <td>6.15</td>\n",
" <td>3.82</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z \\\n",
"4500 0.9 Premium H SI1 61.9 58.0 3629 6.2 6.15 3.82 \n",
"\n",
" above_average_carat \n",
"4500 1 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>Length_to_Width_Ratio</th>\n",
" <th>carat</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" <th>cut_Good</th>\n",
" <th>cut_Ideal</th>\n",
" <th>...</th>\n",
" <th>color_I</th>\n",
" <th>color_J</th>\n",
" <th>clarity_IF</th>\n",
" <th>clarity_SI1</th>\n",
" <th>clarity_SI2</th>\n",
" <th>clarity_VS1</th>\n",
" <th>clarity_VS2</th>\n",
" <th>clarity_VVS1</th>\n",
" <th>clarity_VVS2</th>\n",
" <th>price</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>4500</th>\n",
" <td>0.420272</td>\n",
" <td>0.363352</td>\n",
" <td>1.156653</td>\n",
" <td>0.217442</td>\n",
" <td>0.10618</td>\n",
" <td>0.245753</td>\n",
" <td>0.399417</td>\n",
" <td>1.168162</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3629.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1 rows × 26 columns</p>\n",
"</div>"
],
"text/plain": [
" x y Length_to_Width_Ratio carat depth table \\\n",
"4500 0.420272 0.363352 1.156653 0.217442 0.10618 0.245753 \n",
"\n",
" z above_average_carat cut_Good cut_Ideal ... color_I \\\n",
"4500 0.399417 1.168162 0.0 0.0 ... 0.0 \n",
"\n",
" color_J clarity_IF clarity_SI1 clarity_SI2 clarity_VS1 clarity_VS2 \\\n",
"4500 0.0 0.0 1.0 0.0 0.0 0.0 \n",
"\n",
" clarity_VVS1 clarity_VVS2 price \n",
"4500 0.0 0.0 3629.0 \n",
"\n",
"[1 rows x 26 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 1 (proba: [4.76016150e-04 9.99523984e-01])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 1'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 4500\n",
"# Double brackets return a one-row DataFrame and keep the original column dtypes\n",
"test = X_test.loc[[example_id]]\n",
"test_preprocessed = preprocessed_df.loc[[example_id]]\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
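{
"cell_type": "markdown",
"metadata": {},
"source": [
"The perfect scores most likely come from the inputs, not from model quality. The preprocessed row above still lists `carat` and even the target `above_average_carat` among the model features. This sketch assumes the label is defined as `carat` greater than its mean; the defining code is not in this section. Under that assumption, a single threshold on `carat` reproduces the label exactly. The values below are synthetic and are not the diamonds data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic carat weights (NOT the diamonds dataset)\n",
"rng = np.random.default_rng(0)\n",
"carat = rng.uniform(0.2, 3.0, size=500).reshape(-1, 1)\n",
"# Assumed label definition: 1 if the weight is above the sample mean\n",
"label = (carat[:, 0] > carat.mean()).astype(int)\n",
"\n",
"# A depth-1 tree (one threshold) is enough to separate the two classes\n",
"stump = DecisionTreeClassifier(max_depth=1, random_state=0).fit(carat, label)\n",
"print(\"training accuracy:\", stump.score(carat, label))\n",
"print(\"split threshold:\", stump.tree_.threshold[0], \"mean:\", carat.mean())"
]
},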
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Hyperparameter tuning with grid search"
]
},
{
"cell_type": "code",
"execution_count": 209,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 2,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 20}"
]
},
"execution_count": 209,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"}\n",
"\n",
"gs_optimizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optimizer.fit(X_train, y_train.values.ravel())\n",
"gs_optimizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training the model with the hyperparameters found by the grid search"
]
},
{
"cell_type": "code",
"execution_count": 210,
"metadata": {},
"outputs": [],
"source": [
"# Strip the \"model__\" pipeline prefix so the keys match RandomForestClassifier arguments\n",
"best_params = {\n",
"    key.removeprefix(\"model__\"): value\n",
"    for key, value in gs_optimizer.best_params_.items()\n",
"}\n",
"\n",
"optimized_model = ensemble.RandomForestClassifier(\n",
"    random_state=random_state, **best_params\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline(\n",
"    [(\"pipeline\", pipeline_end), (\"model\", optimized_model)]\n",
").fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
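{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inside a `Pipeline`, `GridSearchCV` reports the parameters with the step-name prefix (for example `model__max_depth`). The prefix has to be stripped before the values are passed to `RandomForestClassifier` directly. With the default `refit=True`, `best_estimator_` is already such a pipeline, refit on the whole training set. Below is a small self-contained check on synthetic data. The dataset, the toy grid, the `scale` step and the `demo_*` names are illustrative assumptions and are not part of this lab:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Synthetic data and a toy grid, for illustration only\n",
"X_demo, y_demo = make_classification(n_samples=200, n_features=6, random_state=0)\n",
"demo_pipe = Pipeline(\n",
"    [(\"scale\", StandardScaler()), (\"model\", RandomForestClassifier(random_state=0))]\n",
")\n",
"demo_grid = {\"model__n_estimators\": [10, 20], \"model__max_depth\": [2, 4]}\n",
"demo_search = GridSearchCV(demo_pipe, param_grid=demo_grid, cv=3).fit(X_demo, y_demo)\n",
"\n",
"# Drop the \"model__\" prefix to get plain RandomForestClassifier arguments\n",
"plain_params = {k.removeprefix(\"model__\"): v for k, v in demo_search.best_params_.items()}\n",
"rebuilt = RandomForestClassifier(random_state=0, **plain_params)\n",
"\n",
"# The refit best pipeline holds an identically configured forest\n",
"refit_forest = demo_search.best_estimator_.named_steps[\"model\"]\n",
"print(plain_params)\n",
"print(rebuilt.get_params() == refit_forest.get_params())"
]
},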
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Building a table that compares the old and new versions of the model"
]
},
{
"cell_type": "code",
"execution_count": 211,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"# Row 0: the original random forest; row 1: the model retrained with the tuned hyperparameters\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(data=class_models[optimized_model_type])\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(data=result)\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Comparing the metrics of the old and new models"
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_703b0_row0_col0, #T_703b0_row0_col1, #T_703b0_row0_col2, #T_703b0_row0_col3, #T_703b0_row1_col0, #T_703b0_row1_col1, #T_703b0_row1_col2, #T_703b0_row1_col3 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_703b0_row0_col4, #T_703b0_row0_col5, #T_703b0_row0_col6, #T_703b0_row0_col7, #T_703b0_row1_col4, #T_703b0_row1_col5, #T_703b0_row1_col6, #T_703b0_row1_col7 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_703b0\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_703b0_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_703b0_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_703b0_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_703b0_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_703b0_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_703b0_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_703b0_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_703b0_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_703b0_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_703b0_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_703b0_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_703b0_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x21f4cb4fe90>"
]
},
"execution_count": 212,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_0ba78_row0_col0, #T_0ba78_row0_col1, #T_0ba78_row1_col0, #T_0ba78_row1_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0ba78_row0_col2, #T_0ba78_row0_col3, #T_0ba78_row0_col4, #T_0ba78_row1_col2, #T_0ba78_row1_col3, #T_0ba78_row1_col4 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_0ba78\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_0ba78_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_0ba78_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_0ba78_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_0ba78_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_0ba78_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_0ba78_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_0ba78_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0ba78_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_0ba78_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x21f4cb4ef30>"
]
},
"execution_count": 213,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 215,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
"\n",
"# One confusion matrix per model version: Old on the left, New on the right\n",
"for index in range(len(optimized_metrics)):\n",
"    c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
"    ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}