{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lab work 4\n",
"\n",
"Dataset: **Diamond prices**, https://www.kaggle.com/datasets/nancyalaswad90/diamonds-prices\n",
"\n",
"1. **carat**: weight of the diamond in carats\n",
"2. **cut**: quality of the cut\n",
"3. **color**: color of the diamond\n",
"4. **clarity**: clarity of the diamond\n",
"5. **depth**: total depth percentage of the diamond (z divided by the mean of x and y)\n",
"6. **table**: width of the table (the flat top facet) as a percentage of the diamond's widest point\n",
"7. **price**: price of the diamond in US dollars\n",
"8. **x**: length of the diamond in millimeters\n",
"9. **y**: width of the diamond in millimeters\n",
"10. **z**: depth of the diamond in millimeters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Business goals**:\n",
"1. Predict the price of a diamond from its characteristics.\n",
"2. Analyze the frequency and combinations of characteristics of the diamonds in highest demand, to plan inventory better."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Loading the dataset"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean value of 'carat': 0.7979346717831785\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.23</td>\n",
" <td>Ideal</td>\n",
" <td>E</td>\n",
" <td>SI2</td>\n",
" <td>61.5</td>\n",
" <td>55.0</td>\n",
" <td>326</td>\n",
" <td>3.95</td>\n",
" <td>3.98</td>\n",
" <td>2.43</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.21</td>\n",
" <td>Premium</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>59.8</td>\n",
" <td>61.0</td>\n",
" <td>326</td>\n",
" <td>3.89</td>\n",
" <td>3.84</td>\n",
" <td>2.31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.23</td>\n",
" <td>Good</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>56.9</td>\n",
" <td>65.0</td>\n",
" <td>327</td>\n",
" <td>4.05</td>\n",
" <td>4.07</td>\n",
" <td>2.31</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.29</td>\n",
" <td>Premium</td>\n",
" <td>I</td>\n",
" <td>VS2</td>\n",
" <td>62.4</td>\n",
" <td>58.0</td>\n",
" <td>334</td>\n",
" <td>4.20</td>\n",
" <td>4.23</td>\n",
" <td>2.63</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.31</td>\n",
" <td>Good</td>\n",
" <td>J</td>\n",
" <td>SI2</td>\n",
" <td>63.3</td>\n",
" <td>58.0</td>\n",
" <td>335</td>\n",
" <td>4.34</td>\n",
" <td>4.35</td>\n",
" <td>2.75</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53939</th>\n",
" <td>0.86</td>\n",
" <td>Premium</td>\n",
" <td>H</td>\n",
" <td>SI2</td>\n",
" <td>61.0</td>\n",
" <td>58.0</td>\n",
" <td>2757</td>\n",
" <td>6.15</td>\n",
" <td>6.12</td>\n",
" <td>3.74</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53940</th>\n",
" <td>0.75</td>\n",
" <td>Ideal</td>\n",
" <td>D</td>\n",
" <td>SI2</td>\n",
" <td>62.2</td>\n",
" <td>55.0</td>\n",
" <td>2757</td>\n",
" <td>5.83</td>\n",
" <td>5.87</td>\n",
" <td>3.64</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53941</th>\n",
" <td>0.71</td>\n",
" <td>Premium</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>60.5</td>\n",
" <td>55.0</td>\n",
" <td>2756</td>\n",
" <td>5.79</td>\n",
" <td>5.74</td>\n",
" <td>3.49</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53942</th>\n",
" <td>0.71</td>\n",
" <td>Premium</td>\n",
" <td>F</td>\n",
" <td>SI1</td>\n",
" <td>59.8</td>\n",
" <td>62.0</td>\n",
" <td>2756</td>\n",
" <td>5.74</td>\n",
" <td>5.73</td>\n",
" <td>3.43</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53943</th>\n",
" <td>0.70</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>60.5</td>\n",
" <td>59.0</td>\n",
" <td>2757</td>\n",
" <td>5.71</td>\n",
" <td>5.76</td>\n",
" <td>3.47</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>53943 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z \\\n",
"id \n",
"1 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 \n",
"2 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 \n",
"3 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 \n",
"4 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 \n",
"5 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 \n",
"... ... ... ... ... ... ... ... ... ... ... \n",
"53939 0.86 Premium H SI2 61.0 58.0 2757 6.15 6.12 3.74 \n",
"53940 0.75 Ideal D SI2 62.2 55.0 2757 5.83 5.87 3.64 \n",
"53941 0.71 Premium E SI1 60.5 55.0 2756 5.79 5.74 3.49 \n",
"53942 0.71 Premium F SI1 59.8 62.0 2756 5.74 5.73 3.43 \n",
"53943 0.70 Very Good E VS2 60.5 59.0 2757 5.71 5.76 3.47 \n",
"\n",
" above_average_carat \n",
"id \n",
"1 0 \n",
"2 0 \n",
"3 0 \n",
"4 0 \n",
"5 0 \n",
"... ... \n",
"53939 1 \n",
"53940 0 \n",
"53941 0 \n",
"53942 0 \n",
"53943 0 \n",
"\n",
"[53943 rows x 11 columns]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"from sklearn import set_config\n",
"\n",
"# Make scikit-learn transformers return pandas DataFrames\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"df = pd.read_csv(\"data/Diamonds.csv\", index_col=\"id\")\n",
"\n",
"random_state = 42\n",
"\n",
"# Mean diamond weight in carats\n",
"average_carat = df['carat'].mean()\n",
"\n",
"print(f\"Mean value of 'carat': {average_carat}\")\n",
"\n",
"# New column: 1 if the weight is above the mean, 0 otherwise\n",
"df['above_average_carat'] = (df['carat'] > average_carat).astype(int)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Splitting the dataset into training and test sets (80/20) for the classification task\n",
"\n",
"Target feature -- `above_average_carat` (1 if the weight is above the mean, 0 otherwise); the split is stratified by this column"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>38836</th>\n",
" <td>0.40</td>\n",
" <td>Very Good</td>\n",
" <td>F</td>\n",
" <td>VVS2</td>\n",
" <td>62.0</td>\n",
" <td>56.0</td>\n",
" <td>1049</td>\n",
" <td>4.71</td>\n",
" <td>4.74</td>\n",
" <td>2.93</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30260</th>\n",
" <td>0.40</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>63.0</td>\n",
" <td>57.0</td>\n",
" <td>725</td>\n",
" <td>4.68</td>\n",
" <td>4.71</td>\n",
" <td>2.96</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33169</th>\n",
" <td>0.36</td>\n",
" <td>Ideal</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>61.8</td>\n",
" <td>56.0</td>\n",
" <td>817</td>\n",
" <td>4.55</td>\n",
" <td>4.58</td>\n",
" <td>2.82</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1029</th>\n",
" <td>0.70</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>58.4</td>\n",
" <td>59.0</td>\n",
" <td>2904</td>\n",
" <td>5.83</td>\n",
" <td>5.91</td>\n",
" <td>3.43</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53809</th>\n",
" <td>0.81</td>\n",
" <td>Very Good</td>\n",
" <td>G</td>\n",
" <td>SI1</td>\n",
" <td>60.7</td>\n",
" <td>56.0</td>\n",
" <td>2733</td>\n",
" <td>6.06</td>\n",
" <td>6.09</td>\n",
" <td>3.69</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2937</th>\n",
" <td>0.77</td>\n",
" <td>Good</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>63.4</td>\n",
" <td>57.0</td>\n",
" <td>3291</td>\n",
" <td>5.80</td>\n",
" <td>5.84</td>\n",
" <td>3.69</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7514</th>\n",
" <td>0.90</td>\n",
" <td>Good</td>\n",
" <td>F</td>\n",
" <td>SI1</td>\n",
" <td>61.8</td>\n",
" <td>63.0</td>\n",
" <td>4241</td>\n",
" <td>6.21</td>\n",
" <td>6.18</td>\n",
" <td>3.83</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48344</th>\n",
" <td>0.56</td>\n",
" <td>Ideal</td>\n",
" <td>H</td>\n",
" <td>VVS1</td>\n",
" <td>62.1</td>\n",
" <td>53.8</td>\n",
" <td>1961</td>\n",
" <td>5.27</td>\n",
" <td>5.33</td>\n",
" <td>3.29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3212</th>\n",
" <td>0.70</td>\n",
" <td>Premium</td>\n",
" <td>F</td>\n",
" <td>VVS1</td>\n",
" <td>61.8</td>\n",
" <td>60.0</td>\n",
" <td>3348</td>\n",
" <td>5.67</td>\n",
" <td>5.63</td>\n",
" <td>3.49</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35654</th>\n",
" <td>0.31</td>\n",
" <td>Very Good</td>\n",
" <td>G</td>\n",
" <td>VVS2</td>\n",
" <td>63.1</td>\n",
" <td>57.0</td>\n",
" <td>907</td>\n",
" <td>4.32</td>\n",
" <td>4.30</td>\n",
" <td>2.72</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>43154 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z\n",
"id \n",
"38836 0.40 Very Good F VVS2 62.0 56.0 1049 4.71 4.74 2.93\n",
"30260 0.40 Very Good E SI1 63.0 57.0 725 4.68 4.71 2.96\n",
"33169 0.36 Ideal E VS1 61.8 56.0 817 4.55 4.58 2.82\n",
"1029 0.70 Very Good E VS1 58.4 59.0 2904 5.83 5.91 3.43\n",
"53809 0.81 Very Good G SI1 60.7 56.0 2733 6.06 6.09 3.69\n",
"... ... ... ... ... ... ... ... ... ... ...\n",
"2937 0.77 Good E VS2 63.4 57.0 3291 5.80 5.84 3.69\n",
"7514 0.90 Good F SI1 61.8 63.0 4241 6.21 6.18 3.83\n",
"48344 0.56 Ideal H VVS1 62.1 53.8 1961 5.27 5.33 3.29\n",
"3212 0.70 Premium F VVS1 61.8 60.0 3348 5.67 5.63 3.49\n",
"35654 0.31 Very Good G VVS2 63.1 57.0 907 4.32 4.30 2.72\n",
"\n",
"[43154 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"id\n",
"38836 0\n",
"30260 0\n",
"33169 0\n",
"1029 0\n",
"53809 1\n",
" ..\n",
"2937 0\n",
"7514 1\n",
"48344 0\n",
"3212 0\n",
"35654 0\n",
"Name: above_average_carat, Length: 43154, dtype: int64"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>32452</th>\n",
" <td>0.39</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>60.9</td>\n",
" <td>58.0</td>\n",
" <td>793</td>\n",
" <td>4.72</td>\n",
" <td>4.77</td>\n",
" <td>2.89</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2432</th>\n",
" <td>0.72</td>\n",
" <td>Very Good</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>63.3</td>\n",
" <td>56.0</td>\n",
" <td>3183</td>\n",
" <td>5.67</td>\n",
" <td>5.71</td>\n",
" <td>3.60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16456</th>\n",
" <td>1.21</td>\n",
" <td>Ideal</td>\n",
" <td>H</td>\n",
" <td>SI1</td>\n",
" <td>62.1</td>\n",
" <td>59.0</td>\n",
" <td>6573</td>\n",
" <td>6.81</td>\n",
" <td>6.75</td>\n",
" <td>4.21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46045</th>\n",
" <td>0.56</td>\n",
" <td>Ideal</td>\n",
" <td>D</td>\n",
" <td>SI1</td>\n",
" <td>62.5</td>\n",
" <td>56.0</td>\n",
" <td>1729</td>\n",
" <td>5.28</td>\n",
" <td>5.24</td>\n",
" <td>3.29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11115</th>\n",
" <td>1.00</td>\n",
" <td>Good</td>\n",
" <td>E</td>\n",
" <td>SI1</td>\n",
" <td>62.4</td>\n",
" <td>59.0</td>\n",
" <td>4936</td>\n",
" <td>6.35</td>\n",
" <td>6.40</td>\n",
" <td>3.98</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40250</th>\n",
" <td>0.50</td>\n",
" <td>Premium</td>\n",
" <td>F</td>\n",
" <td>SI1</td>\n",
" <td>59.6</td>\n",
" <td>61.0</td>\n",
" <td>1125</td>\n",
" <td>5.15</td>\n",
" <td>5.12</td>\n",
" <td>3.06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3308</th>\n",
" <td>0.73</td>\n",
" <td>Ideal</td>\n",
" <td>E</td>\n",
" <td>VS1</td>\n",
" <td>62.3</td>\n",
" <td>56.0</td>\n",
" <td>3370</td>\n",
" <td>5.75</td>\n",
" <td>5.80</td>\n",
" <td>3.60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7894</th>\n",
" <td>1.12</td>\n",
" <td>Very Good</td>\n",
" <td>I</td>\n",
" <td>SI1</td>\n",
" <td>60.6</td>\n",
" <td>60.0</td>\n",
" <td>4312</td>\n",
" <td>6.73</td>\n",
" <td>6.77</td>\n",
" <td>4.09</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21368</th>\n",
" <td>0.36</td>\n",
" <td>Ideal</td>\n",
" <td>D</td>\n",
" <td>SI1</td>\n",
" <td>62.2</td>\n",
" <td>53.0</td>\n",
" <td>626</td>\n",
" <td>4.57</td>\n",
" <td>4.59</td>\n",
" <td>2.85</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46144</th>\n",
" <td>0.50</td>\n",
" <td>Premium</td>\n",
" <td>E</td>\n",
" <td>VS2</td>\n",
" <td>61.3</td>\n",
" <td>59.0</td>\n",
" <td>1746</td>\n",
" <td>5.10</td>\n",
" <td>5.05</td>\n",
" <td>3.11</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>10789 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z\n",
"id \n",
"32452 0.39 Very Good E VS2 60.9 58.0 793 4.72 4.77 2.89\n",
"2432 0.72 Very Good E SI1 63.3 56.0 3183 5.67 5.71 3.60\n",
"16456 1.21 Ideal H SI1 62.1 59.0 6573 6.81 6.75 4.21\n",
"46045 0.56 Ideal D SI1 62.5 56.0 1729 5.28 5.24 3.29\n",
"11115 1.00 Good E SI1 62.4 59.0 4936 6.35 6.40 3.98\n",
"... ... ... ... ... ... ... ... ... ... ...\n",
"40250 0.50 Premium F SI1 59.6 61.0 1125 5.15 5.12 3.06\n",
"3308 0.73 Ideal E VS1 62.3 56.0 3370 5.75 5.80 3.60\n",
"7894 1.12 Very Good I SI1 60.6 60.0 4312 6.73 6.77 4.09\n",
"21368 0.36 Ideal D SI1 62.2 53.0 626 4.57 4.59 2.85\n",
"46144 0.50 Premium E VS2 61.3 59.0 1746 5.10 5.05 3.11\n",
"\n",
"[10789 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"id\n",
"32452 0\n",
"2432 0\n",
"16456 1\n",
"46045 0\n",
"11115 1\n",
" ..\n",
"40250 0\n",
"3308 0\n",
"7894 1\n",
"21368 0\n",
"46144 0\n",
"Name: above_average_carat, Length: 10789, dtype: int64"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import math\n",
"\n",
"import pandas as pd\n",
"from pandas import DataFrame, Series\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input,\n",
"    stratify_colname=\"y\",\n",
"    frac_train=0.6,\n",
"    frac_val=0.15,\n",
"    frac_test=0.25,\n",
"    random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, Series, Series, Series]:\n",
"    \"\"\"\n",
"    Splits a Pandas dataframe into three subsets (train, val, and test)\n",
"    following fractional ratios provided by the user, where each subset is\n",
"    stratified by the values in a specific column (that is, each subset has\n",
"    the same relative frequency of the values in the column). It performs this\n",
"    splitting by running train_test_split() twice.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df_input : Pandas dataframe\n",
"        Input dataframe to be split.\n",
"    stratify_colname : str\n",
"        The name of the column that will be used for stratification. Usually\n",
"        this column would be for the label.\n",
"    frac_train : float\n",
"    frac_val : float\n",
"    frac_test : float\n",
"        The ratios with which the dataframe will be split into train, val, and\n",
"        test data. The values should be expressed as float fractions and should\n",
"        sum to 1.0.\n",
"    random_state : int, None, or RandomStateInstance\n",
"        Value to be passed to train_test_split().\n",
"\n",
"    Returns\n",
"    -------\n",
"    X_train, X_val, X_test : DataFrame\n",
"        Features (all columns except stratify_colname) of the three splits.\n",
"    y_train, y_val, y_test : Series\n",
"        Values of stratify_colname for the three splits.\n",
"    \"\"\"\n",
"    if not math.isclose(frac_train + frac_val + frac_test, 1.0):\n",
"        raise ValueError(\n",
"            \"fractions %f, %f, %f do not add up to 1.0\"\n",
"            % (frac_train, frac_val, frac_test)\n",
"        )\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"    X = df_input.drop(columns=[stratify_colname])  # all columns except the target\n",
"    y = df_input[stratify_colname]  # column used for stratification (the target)\n",
"    # First split: training set and a temporary set\n",
"    df_train, df_temp, y_train, y_temp = train_test_split(\n",
"        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
"    )\n",
"    if frac_val <= 0:\n",
"        assert len(df_input) == len(df_train) + len(df_temp)\n",
"        return df_train, pd.DataFrame(), df_temp, y_train, pd.Series(dtype=y.dtype), y_temp\n",
"    # Second split: validation and test sets\n",
"    relative_frac_test = frac_test / (frac_val + frac_test)\n",
"    df_val, df_test, y_val, y_test = train_test_split(\n",
"        df_temp,\n",
"        y_temp,\n",
"        stratify=y_temp,\n",
"        test_size=relative_frac_test,\n",
"        random_state=random_state,\n",
"    )\n",
"    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"    return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"# 80% training / 20% test split, stratified by the \"above_average_carat\" column\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
"    df, stratify_colname=\"above_average_carat\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
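{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to check that the stratified split kept the class balance, as a minimal sketch: the hypothetical helper `class_proportions` tabulates the share of each target class in every split passed to it. It is demonstrated on a small synthetic target with 30% positives (an assumption, not the diamonds data). In this notebook one could call `class_proportions(train=y_train, test=y_test)` and expect the two columns to be nearly equal."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def class_proportions(**splits):\n",
"    # Hypothetical helper: share of each class per split, one column per split\n",
"    return pd.DataFrame(\n",
"        {name: s.value_counts(normalize=True).sort_index() for name, s in splits.items()}\n",
"    )\n",
"\n",
"# Synthetic stand-in for the binary target above_average_carat (30% ones)\n",
"y_demo = pd.Series([1] * 30 + [0] * 70, name=\"above_average_carat\")\n",
"X_demo = pd.DataFrame({\"feature\": range(100)})\n",
"X_tr, X_te, y_tr, y_te = train_test_split(\n",
"    X_demo, y_demo, stratify=y_demo, test_size=0.2, random_state=42\n",
")\n",
"props = class_proportions(full=y_demo, train=y_tr, test=y_te)\n",
"print(props)"
]
},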
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Building the pipeline for classification\n",
"\n",
"preprocessing_num -- pipeline for numeric data: imputation of missing values and standardization\n",
"\n",
"preprocessing_cat -- pipeline for categorical data: imputation of missing values and one-hot encoding\n",
"\n",
"features_preprocessing -- transformer for feature preprocessing\n",
"\n",
"features_engineering -- transformer for feature construction\n",
"\n",
"drop_columns -- transformer for dropping columns\n",
"\n",
"pipeline_end -- main pipeline for data preprocessing and feature construction"
]
},
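{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustration of how `OneHotEncoder(handle_unknown=\"ignore\", drop=\"first\")` behaves; these are the settings of `cat_encoder` in the next cell. The toy `cut` values are an assumption for this sketch. The first category (`Fair`) is dropped. A category not seen during `fit` (`Very Good` here) is encoded as all zeros, which makes it indistinguishable from the dropped category. scikit-learn may emit a warning in that case, and the sketch silences it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"import numpy as np\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"# Same settings as cat_encoder in the pipeline cell\n",
"encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"encoder.fit(np.array([[\"Fair\"], [\"Good\"], [\"Ideal\"]], dtype=object))\n",
"print(encoder.get_feature_names_out([\"cut\"]))  # 'Fair' is dropped: cut_Good, cut_Ideal\n",
"\n",
"with warnings.catch_warnings():\n",
"    # scikit-learn may warn that unknown categories are encoded as all zeros\n",
"    warnings.simplefilter(\"ignore\")\n",
"    encoded = encoder.transform(np.array([[\"Fair\"], [\"Very Good\"]], dtype=object))\n",
"print(encoded)  # both rows are all zeros"
]
},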
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"class DaimondFeatures(BaseEstimator, TransformerMixin):\n",
"    def __init__(self):\n",
"        pass\n",
"\n",
"    def fit(self, X, y=None):\n",
"        return self\n",
"\n",
"    # Add the column \"Length_to_Width_Ratio\": length divided by width (columns x and y).\n",
"    # In pipeline_end this step runs after features_preprocessing, so x and y arrive\n",
"    # here already standardized and the ratio is computed on scaled values.\n",
"    def transform(self, X, y=None):\n",
"        X = X.copy()  # do not modify the caller's DataFrame\n",
"        X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n",
"        return X\n",
"\n",
"    # Append the new column name to the input feature names\n",
"    def get_feature_names_out(self, features_in):\n",
"        return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
"\n",
"# Columns to drop from the data\n",
"columns_to_drop = []\n",
"# Numeric columns\n",
"num_columns = [\"carat\", \"depth\", \"table\", \"x\", \"y\", \"z\"]\n",
"# Categorical columns\n",
"cat_columns = [\"cut\", \"color\", \"clarity\"]\n",
"\n",
"# Preprocessing of numeric columns\n",
"# Fill missing values in numeric columns with the median\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"# Standardize numeric columns to zero mean and unit standard deviation\n",
"num_scaler = StandardScaler()\n",
"# Pipeline that applies SimpleImputer and then StandardScaler\n",
"preprocessing_num = Pipeline(\n",
"    [\n",
"        (\"imputer\", num_imputer),\n",
"        (\"scaler\", num_scaler),\n",
"    ]\n",
")\n",
"# Preprocessing of categorical columns\n",
"# Fill missing values in categorical columns with the constant \"unknown\"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"# One-hot encode the categorical variables\n",
"# handle_unknown=\"ignore\": categories unseen during fit are encoded as all zeros instead of raising an error\n",
"# drop=\"first\": drop the first category of each feature to avoid multicollinearity\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
"    [\n",
"        (\"imputer\", cat_imputer),\n",
"        (\"encoder\", cat_encoder),\n",
"    ]\n",
")\n",
"\n",
"# Combine the preprocessing steps\n",
"# ColumnTransformer applies different transformations to different columns\n",
"features_preprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    # numeric and categorical features are handled by separate pipelines\n",
"    transformers=[\n",
"        (\"prepocessing_num\", preprocessing_num, num_columns),\n",
"        (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
"    ],\n",
"    # all columns not listed in the transformers (here: price) are passed through unchanged\n",
"    remainder=\"passthrough\"\n",
")\n",
"# Feature construction\n",
"# A second ColumnTransformer applies DaimondFeatures to the columns x and y;\n",
"# all other columns are passed through unchanged\n",
"features_engineering = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"add_features\", DaimondFeatures(), [\"x\", \"y\"]),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"# Column removal; nothing is dropped here because columns_to_drop is empty\n",
"drop_columns = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"drop_columns\", \"drop\", columns_to_drop),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"# Column post-processing\n",
"# Applies preprocessing_cat to the column Cabin_type. Note: this dataset has no\n",
"# Cabin_type column, and features_postprocessing is not used in pipeline_end.\n",
"features_postprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"# Final pipeline\n",
"pipeline_end = Pipeline(\n",
"    [\n",
"        (\"features_preprocessing\", features_preprocessing),\n",
"        (\"features_engineering\", features_engineering),\n",
"        (\"drop_columns\", drop_columns),\n",
"    ]\n",
")"
]
},
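{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `pipeline_end`, `features_engineering` runs after `features_preprocessing`, so `DaimondFeatures` divides the standardized `x` by the standardized `y`. One possible alternative computes the ratio on the raw millimeter values, before scaling. It is sketched here under the assumption that the ratio is meant to describe the physical shape of the stone. The function `add_length_to_width_ratio` is hypothetical. Zero widths, if present, are mapped to NaN so that an imputer can fill them instead of producing infinities. Wrapping the function in `FunctionTransformer` makes it usable as a pipeline step, for example placed before `features_preprocessing`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.preprocessing import FunctionTransformer\n",
"\n",
"def add_length_to_width_ratio(X: pd.DataFrame) -> pd.DataFrame:\n",
"    \"\"\"Hypothetical helper: add x / y computed on unscaled millimeter values.\"\"\"\n",
"    X = X.copy()  # leave the caller's DataFrame unchanged\n",
"    # A zero width would give inf (or NaN for 0/0); map it to NaN so an imputer can fill it\n",
"    X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"].replace(0, np.nan)\n",
"    return X\n",
"\n",
"# Hypothetical pipeline step; it could precede features_preprocessing in a reordered pipeline\n",
"ratio_step = FunctionTransformer(add_length_to_width_ratio)\n",
"\n",
"raw = pd.DataFrame({\"x\": [3.95, 4.20, 0.0], \"y\": [3.98, 4.23, 0.0]})\n",
"print(ratio_step.fit_transform(raw))"
]
},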
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Demonstration of the preprocessing pipeline for classification"
]
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>x</th>\n",
|
||
" <th>y</th>\n",
|
||
" <th>Length_to_Width_Ratio</th>\n",
|
||
" <th>carat</th>\n",
|
||
" <th>depth</th>\n",
|
||
" <th>table</th>\n",
|
||
" <th>z</th>\n",
|
||
" <th>cut_Good</th>\n",
|
||
" <th>cut_Ideal</th>\n",
|
||
" <th>cut_Premium</th>\n",
|
||
" <th>...</th>\n",
|
||
" <th>color_I</th>\n",
|
||
" <th>color_J</th>\n",
|
||
" <th>clarity_IF</th>\n",
|
||
" <th>clarity_SI1</th>\n",
|
||
" <th>clarity_SI2</th>\n",
|
||
" <th>clarity_VS1</th>\n",
|
||
" <th>clarity_VS2</th>\n",
|
||
" <th>clarity_VVS1</th>\n",
|
||
" <th>clarity_VVS2</th>\n",
|
||
" <th>price</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>id</th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" <th></th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>38836</th>\n",
|
||
" <td>-0.907744</td>\n",
|
||
" <td>-0.863476</td>\n",
|
||
" <td>1.051267</td>\n",
|
||
" <td>-0.837490</td>\n",
|
||
" <td>0.176170</td>\n",
|
||
" <td>-0.648004</td>\n",
|
||
" <td>-0.857040</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1049</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>30260</th>\n",
|
||
" <td>-0.934483</td>\n",
|
||
" <td>-0.889579</td>\n",
|
||
" <td>1.050478</td>\n",
|
||
" <td>-0.837490</td>\n",
|
||
" <td>0.876071</td>\n",
|
||
" <td>-0.201125</td>\n",
|
||
" <td>-0.814688</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>725</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>33169</th>\n",
|
||
" <td>-1.050350</td>\n",
|
||
" <td>-1.002691</td>\n",
|
||
" <td>1.047532</td>\n",
|
||
" <td>-0.921885</td>\n",
|
||
" <td>0.036190</td>\n",
|
||
" <td>-0.648004</td>\n",
|
||
" <td>-1.012333</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>817</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1029</th>\n",
|
||
" <td>0.090496</td>\n",
|
||
" <td>0.154530</td>\n",
|
||
" <td>0.585622</td>\n",
|
||
" <td>-0.204531</td>\n",
|
||
" <td>-2.343471</td>\n",
|
||
" <td>0.692631</td>\n",
|
||
" <td>-0.151165</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>2904</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>53809</th>\n",
|
||
" <td>0.295492</td>\n",
|
||
" <td>0.311147</td>\n",
|
||
" <td>0.949688</td>\n",
|
||
" <td>0.027554</td>\n",
|
||
" <td>-0.733700</td>\n",
|
||
" <td>-0.648004</td>\n",
|
||
" <td>0.215890</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>2733</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2937</th>\n",
|
||
" <td>0.063758</td>\n",
|
||
" <td>0.093624</td>\n",
|
||
" <td>0.680999</td>\n",
|
||
" <td>-0.056841</td>\n",
|
||
" <td>1.156031</td>\n",
|
||
" <td>-0.201125</td>\n",
|
||
" <td>0.215890</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>3291</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>7514</th>\n",
|
||
" <td>0.429185</td>\n",
|
||
" <td>0.389455</td>\n",
|
||
" <td>1.102015</td>\n",
|
||
" <td>0.217442</td>\n",
|
||
" <td>0.036190</td>\n",
|
||
" <td>2.480145</td>\n",
|
||
" <td>0.413535</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>4241</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>48344</th>\n",
|
||
" <td>-0.408624</td>\n",
|
||
" <td>-0.350123</td>\n",
|
||
" <td>1.167088</td>\n",
|
||
" <td>-0.499912</td>\n",
|
||
" <td>0.246160</td>\n",
|
||
" <td>-1.631136</td>\n",
|
||
" <td>-0.348810</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1961</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3212</th>\n",
|
||
" <td>-0.052109</td>\n",
|
||
" <td>-0.089095</td>\n",
|
||
" <td>0.584874</td>\n",
|
||
" <td>-0.204531</td>\n",
|
||
" <td>0.036190</td>\n",
|
||
" <td>1.139510</td>\n",
|
||
" <td>-0.066460</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>3348</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>35654</th>\n",
|
||
" <td>-1.255346</td>\n",
|
||
" <td>-1.246316</td>\n",
|
||
" <td>1.007245</td>\n",
|
||
" <td>-1.027378</td>\n",
|
||
" <td>0.946061</td>\n",
|
||
" <td>-0.201125</td>\n",
|
||
" <td>-1.153508</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>907</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>43154 rows × 25 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" x y Length_to_Width_Ratio carat depth \\\n",
|
||
"id \n",
|
||
"38836 -0.907744 -0.863476 1.051267 -0.837490 0.176170 \n",
|
||
"30260 -0.934483 -0.889579 1.050478 -0.837490 0.876071 \n",
|
||
"33169 -1.050350 -1.002691 1.047532 -0.921885 0.036190 \n",
|
||
"1029 0.090496 0.154530 0.585622 -0.204531 -2.343471 \n",
|
||
"53809 0.295492 0.311147 0.949688 0.027554 -0.733700 \n",
|
||
"... ... ... ... ... ... \n",
|
||
"2937 0.063758 0.093624 0.680999 -0.056841 1.156031 \n",
|
||
"7514 0.429185 0.389455 1.102015 0.217442 0.036190 \n",
|
||
"48344 -0.408624 -0.350123 1.167088 -0.499912 0.246160 \n",
|
||
"3212 -0.052109 -0.089095 0.584874 -0.204531 0.036190 \n",
|
||
"35654 -1.255346 -1.246316 1.007245 -1.027378 0.946061 \n",
|
||
"\n",
|
||
" table z cut_Good cut_Ideal cut_Premium ... color_I \\\n",
|
||
"id ... \n",
|
||
"38836 -0.648004 -0.857040 0.0 0.0 0.0 ... 0.0 \n",
|
||
"30260 -0.201125 -0.814688 0.0 0.0 0.0 ... 0.0 \n",
|
||
"33169 -0.648004 -1.012333 0.0 1.0 0.0 ... 0.0 \n",
|
||
"1029 0.692631 -0.151165 0.0 0.0 0.0 ... 0.0 \n",
|
||
"53809 -0.648004 0.215890 0.0 0.0 0.0 ... 0.0 \n",
|
||
"... ... ... ... ... ... ... ... \n",
|
||
"2937 -0.201125 0.215890 1.0 0.0 0.0 ... 0.0 \n",
|
||
"7514 2.480145 0.413535 1.0 0.0 0.0 ... 0.0 \n",
|
||
"48344 -1.631136 -0.348810 0.0 1.0 0.0 ... 0.0 \n",
|
||
"3212 1.139510 -0.066460 0.0 0.0 1.0 ... 0.0 \n",
|
||
"35654 -0.201125 -1.153508 0.0 0.0 0.0 ... 0.0 \n",
|
||
"\n",
|
||
" color_J clarity_IF clarity_SI1 clarity_SI2 clarity_VS1 \\\n",
|
||
"id \n",
|
||
"38836 0.0 0.0 0.0 0.0 0.0 \n",
|
||
"30260 0.0 0.0 1.0 0.0 0.0 \n",
|
||
"33169 0.0 0.0 0.0 0.0 1.0 \n",
|
||
"1029 0.0 0.0 0.0 0.0 1.0 \n",
|
||
"53809 0.0 0.0 1.0 0.0 0.0 \n",
|
||
"... ... ... ... ... ... \n",
|
||
"2937 0.0 0.0 0.0 0.0 0.0 \n",
|
||
"7514 0.0 0.0 1.0 0.0 0.0 \n",
|
||
"48344 0.0 0.0 0.0 0.0 0.0 \n",
|
||
"3212 0.0 0.0 0.0 0.0 0.0 \n",
|
||
"35654 0.0 0.0 0.0 0.0 0.0 \n",
|
||
"\n",
|
||
" clarity_VS2 clarity_VVS1 clarity_VVS2 price \n",
|
||
"id \n",
|
||
"38836 0.0 0.0 1.0 1049 \n",
|
||
"30260 0.0 0.0 0.0 725 \n",
|
||
"33169 0.0 0.0 0.0 817 \n",
|
||
"1029 0.0 0.0 0.0 2904 \n",
|
||
"53809 0.0 0.0 0.0 2733 \n",
|
||
"... ... ... ... ... \n",
|
||
"2937 1.0 0.0 0.0 3291 \n",
|
||
"7514 0.0 0.0 0.0 4241 \n",
|
||
"48344 0.0 1.0 0.0 1961 \n",
|
||
"3212 0.0 1.0 0.0 3348 \n",
|
||
"35654 0.0 0.0 1.0 907 \n",
|
||
"\n",
|
||
"[43154 rows x 25 columns]"
|
||
]
|
||
},
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"#Обучает все трансформеры в конвейере на данных обучающей выборки\n",
|
||
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
||
"preprocessed_df = pd.DataFrame(\n",
|
||
" preprocessing_result,\n",
|
||
" #возвращает имена всех столбцов, которые были созданы в результате всех этапов предобработки в конвейере\n",
|
||
" columns=pipeline_end.get_feature_names_out(),\n",
|
||
")\n",
|
||
"\n",
|
||
"preprocessed_df"
|
||
]
|
||
},
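{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "A minimal, self-contained sketch of how `get_feature_names_out()` can yield column names such as `cut_Good`. It assumes that the categorical features are one-hot encoded with `OneHotEncoder(drop=\"first\")` inside a `ColumnTransformer` created with `verbose_feature_names_out=False`. This notebook section does not confirm that setup. It is, however, consistent with the table above: the first one-hot column is `cut_Good`, not `cut_Fair`, and the 25 columns add up only if one category per feature is dropped. The real `features_preprocessing` may differ. The frame `toy_df` and all other `toy_*` names are hypothetical, and `sparse_output` requires scikit-learn 1.2 or newer."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import pandas as pd\n",
  "from sklearn.compose import ColumnTransformer\n",
  "from sklearn.pipeline import Pipeline\n",
  "from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
  "\n",
  "# Hypothetical toy data, not taken from the diamonds dataset\n",
  "toy_df = pd.DataFrame({\"carat\": [0.3, 0.5, 1.2], \"cut\": [\"Fair\", \"Good\", \"Ideal\"]})\n",
  "\n",
  "# Assumed structure: scale the numeric column and one-hot encode the categorical one,\n",
  "# dropping the first category (\"Fair\") so that no redundant column is produced\n",
  "toy_preprocessing = ColumnTransformer(\n",
  "    [\n",
  "        (\"num\", StandardScaler(), [\"carat\"]),\n",
  "        (\"cat\", OneHotEncoder(drop=\"first\", sparse_output=False), [\"cut\"]),\n",
  "    ],\n",
  "    # keep plain names such as \"cut_Good\" instead of \"cat__cut_Good\"\n",
  "    verbose_feature_names_out=False,\n",
  ")\n",
  "toy_pipeline = Pipeline([(\"features_preprocessing\", toy_preprocessing)])\n",
  "\n",
  "toy_result = toy_pipeline.fit_transform(toy_df)\n",
  "toy_names = toy_pipeline.get_feature_names_out()\n",
  "print(list(toy_names))  # ['carat', 'cut_Good', 'cut_Ideal']"
 ]
},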
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Формирование набора моделей для классификации"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 42,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
|
||
"\n",
|
||
"#классификационные модели\n",
|
||
"class_models = {\n",
|
||
" #Логистическая регрессия\n",
|
||
" # от 0 до 1, принадлежит ли объект к классу\n",
|
||
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
|
||
" #гребневая регрессия\n",
|
||
" #Логическая, но с регуляризацией L2 (модель не так точно запоминает данные) и сбалансированными весами классов\n",
|
||
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
|
||
" #дерево решений с максимальной глубиной 7\n",
|
||
" #Деления данных на условия с помощью построения дерева\n",
|
||
" \"decision_tree\": {\n",
|
||
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
|
||
" },\n",
|
||
" #K-ближайших соседей с количеством соседей, равным 7\n",
|
||
" #Определяет ближайших объектов и находит и класс\n",
|
||
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
|
||
" #Наивный байесовский классификатор - Вероятности для классификации\n",
|
||
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
|
||
" #Градиентный бустинг с 210 деревьями\n",
|
||
" #Постепенно улучшает предсказания с помощью слабых моделей\n",
|
||
" \"gradient_boosting\": {\n",
|
||
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
|
||
" },\n",
|
||
" #Случайный лес с максимальной глубиной 11 и сбалансированными весами классов\n",
|
||
" \"random_forest\": {\n",
|
||
" \"model\": ensemble.RandomForestClassifier(\n",
|
||
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
|
||
" )\n",
|
||
" },\n",
|
||
" #Многослойный персептрон (нейронная сеть)с одним скрытым слоем\n",
|
||
" \"mlp\": {\n",
|
||
" \"model\": neural_network.MLPClassifier(\n",
|
||
" #содержажит 7 нейронов\n",
|
||
" hidden_layer_sizes=(7,),\n",
|
||
" #максимальное количеством итераций 500\n",
|
||
" max_iter=500,\n",
|
||
" #включение ранней остановки для предотвращения переобучения\n",
|
||
" early_stopping=True,\n",
|
||
" random_state=random_state,\n",
|
||
" )\n",
|
||
" },\n",
|
||
"}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 43,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: logistic\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"d:\\3 kurs\\МИИ\\1 лаб\\mai-main\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
||
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
||
"\n",
|
||
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
||
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
||
"Please also refer to the documentation for alternative solver options:\n",
|
||
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
||
" n_iter_i = _check_optimize_result(\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: ridge\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"d:\\3 kurs\\МИИ\\1 лаб\\mai-main\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
||
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
||
"\n",
|
||
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
||
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
||
"Please also refer to the documentation for alternative solver options:\n",
|
||
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
||
" n_iter_i = _check_optimize_result(\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: decision_tree\n",
|
||
"Model: knn\n",
|
||
"Model: naive_bayes\n",
|
||
"Model: gradient_boosting\n",
|
||
"Model: random_forest\n",
|
||
"Model: mlp\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"\n",
|
||
"#проходим по всем моделям\n",
|
||
"for model_name in class_models.keys():\n",
|
||
" #выводим названия\n",
|
||
" print(f\"Model: {model_name}\")\n",
|
||
" model = class_models[model_name][\"model\"]\n",
|
||
"\n",
|
||
" #Получение модели и создание конвейера\n",
|
||
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
||
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
||
"\n",
|
||
" #Предсказания на обучающей и тестовой выборке\n",
|
||
" y_train_predict = model_pipeline.predict(X_train)\n",
|
||
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
||
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
||
"\n",
|
||
" #Сохранение результатов в словаре моделей\n",
|
||
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
||
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
||
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
||
"\n",
|
||
" #Оценка производительности модели\n",
|
||
" #точность\n",
|
||
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" #полнота\n",
|
||
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" #аккуратность (верность) - Доля правильных предсказаний среди всех\n",
|
||
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" #Площадь под кривой ROC, которая показывает качество классификации\n",
|
||
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
|
||
" y_test, y_test_probs\n",
|
||
" )\n",
|
||
" #Гармоническое среднее между точностью и полнотой\n",
|
||
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
|
||
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
|
||
" #Коэффициент корреляции Мэтьюса, который учитывает все возможные результаты классификации\n",
|
||
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" #Мера согласия между двумя классификаторами\n",
|
||
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" #Матрица ошибок, показывающая количество истинных и ложных положительных и отрицательных предсказаний\n",
|
||
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )"
|
||
]
|
||
},
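{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "To make the metrics computed above concrete, here is a small example on hypothetical labels (not notebook data) whose values can be checked by hand. With the same threshold rule `y_prob > 0.5` as in the loop, the toy predictions give TP = 2, FP = 1, FN = 1 and TN = 2. Precision, recall and F1 are therefore all 2/3, and accuracy is 4/6. MCC and Cohen's kappa are both 1/3. ROC AUC is 7/9 because `y_prob` ranks 7 of the 9 positive/negative pairs correctly. All names in the cell (`y_true`, `y_prob`, `toy_scores`) are illustrative."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import numpy as np\n",
  "from sklearn import metrics\n",
  "\n",
  "# Hypothetical labels and probabilities chosen so every metric can be checked by hand\n",
  "y_true = np.array([0, 0, 0, 1, 1, 1])\n",
  "y_prob = np.array([0.1, 0.4, 0.6, 0.7, 0.8, 0.3])\n",
  "# Same thresholding rule as in the training loop\n",
  "y_pred = np.where(y_prob > 0.5, 1, 0)  # [0, 0, 1, 1, 1, 0]\n",
  "\n",
  "# Rows are true classes, columns are predicted classes: [[TN, FP], [FN, TP]]\n",
  "tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()\n",
  "\n",
  "toy_scores = {\n",
  "    \"precision\": metrics.precision_score(y_true, y_pred),  # TP / (TP + FP) = 2/3\n",
  "    \"recall\": metrics.recall_score(y_true, y_pred),  # TP / (TP + FN) = 2/3\n",
  "    \"accuracy\": metrics.accuracy_score(y_true, y_pred),  # (TP + TN) / N = 4/6\n",
  "    \"f1\": metrics.f1_score(y_true, y_pred),  # 2PR / (P + R) = 2/3\n",
  "    \"roc_auc\": metrics.roc_auc_score(y_true, y_prob),  # 7 of 9 pairs ranked correctly\n",
  "    \"mcc\": metrics.matthews_corrcoef(y_true, y_pred),  # (TP*TN - FP*FN) / sqrt(...) = 1/3\n",
  "    \"kappa\": metrics.cohen_kappa_score(y_true, y_pred),  # (2/3 - 1/2) / (1 - 1/2) = 1/3\n",
  "}\n",
  "print(tn, fp, fn, tp)  # 2 1 1 2\n",
  "for name, value in toy_scores.items():\n",
  "    print(f\"{name}: {value:.4f}\")"
 ]
},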
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Сводная таблица оценок качества для использованных моделей классификации"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Матрица неточностей"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 44,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1200x1000 with 16 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"\n",
|
||
"#Функция, создающая фигуру и набор подграфиков (axes)\n",
|
||
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
|
||
"for index, key in enumerate(class_models.keys()):\n",
|
||
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
|
||
" disp = ConfusionMatrixDisplay(\n",
|
||
" confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
|
||
" ).plot(ax=ax.flat[index])\n",
|
||
" disp.ax_.set_title(key)\n",
|
||
"\n",
|
||
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Точность, полнота, верность (аккуратность), F-мера"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 46,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_b8c4d_row0_col0, #T_b8c4d_row0_col1, #T_b8c4d_row0_col2, #T_b8c4d_row0_col3, #T_b8c4d_row1_col0, #T_b8c4d_row1_col1, #T_b8c4d_row1_col2, #T_b8c4d_row1_col3, #T_b8c4d_row2_col0, #T_b8c4d_row2_col1, #T_b8c4d_row2_col2, #T_b8c4d_row2_col3 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row0_col4, #T_b8c4d_row0_col5, #T_b8c4d_row0_col6, #T_b8c4d_row0_col7, #T_b8c4d_row1_col4, #T_b8c4d_row1_col5, #T_b8c4d_row1_col6, #T_b8c4d_row1_col7, #T_b8c4d_row2_col4, #T_b8c4d_row2_col5, #T_b8c4d_row2_col6, #T_b8c4d_row2_col7 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row3_col0 {\n",
|
||
" background-color: #7fd34e;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row3_col1, #T_b8c4d_row5_col0 {\n",
|
||
" background-color: #8ed645;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row3_col2 {\n",
|
||
" background-color: #90d743;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row3_col3, #T_b8c4d_row4_col1 {\n",
|
||
" background-color: #93d741;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row3_col4, #T_b8c4d_row3_col6 {\n",
|
||
" background-color: #ca457a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row3_col5, #T_b8c4d_row3_col7 {\n",
|
||
" background-color: #cf4c74;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row4_col0, #T_b8c4d_row4_col3 {\n",
|
||
" background-color: #86d549;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row4_col2 {\n",
|
||
" background-color: #6ece58;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row4_col4 {\n",
|
||
" background-color: #c43e7f;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row4_col5, #T_b8c4d_row4_col7 {\n",
|
||
" background-color: #ce4b75;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row4_col6 {\n",
|
||
" background-color: #c5407e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row5_col1, #T_b8c4d_row6_col3 {\n",
|
||
" background-color: #9bd93c;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row5_col2 {\n",
|
||
" background-color: #5ac864;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row5_col3 {\n",
|
||
" background-color: #73d056;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row5_col4, #T_b8c4d_row5_col6 {\n",
|
||
" background-color: #c13b82;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row5_col5, #T_b8c4d_row5_col7 {\n",
|
||
" background-color: #cc4778;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row6_col0, #T_b8c4d_row6_col1, #T_b8c4d_row7_col2, #T_b8c4d_row7_col3 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row6_col2 {\n",
|
||
" background-color: #9dd93b;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row6_col4 {\n",
|
||
" background-color: #6300a7;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row6_col5 {\n",
|
||
" background-color: #7100a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row6_col6 {\n",
|
||
" background-color: #6700a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row6_col7 {\n",
|
||
" background-color: #7501a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row7_col0 {\n",
|
||
" background-color: #2eb37c;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row7_col1 {\n",
|
||
" background-color: #25ac82;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_b8c4d_row7_col4, #T_b8c4d_row7_col5, #T_b8c4d_row7_col6, #T_b8c4d_row7_col7 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_b8c4d\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_b8c4d_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
||
" <th id=\"T_b8c4d_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
||
" <th id=\"T_b8c4d_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
||
" <th id=\"T_b8c4d_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
||
" <th id=\"T_b8c4d_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
||
" <th id=\"T_b8c4d_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_b8c4d_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
||
" <th id=\"T_b8c4d_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_b8c4d_level0_row0\" class=\"row_heading level0 row0\" >decision_tree</th>\n",
|
||
" <td id=\"T_b8c4d_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_b8c4d_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
|
||
" <td id=\"T_b8c4d_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_b8c4d_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
||
" <td id=\"T_b8c4d_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_b8c4d_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_b8c4d_level0_row3\" class=\"row_heading level0 row3\" >mlp</th>\n",
|
||
" <td id=\"T_b8c4d_row3_col0\" class=\"data row3 col0\" >0.992471</td>\n",
|
||
" <td id=\"T_b8c4d_row3_col1\" class=\"data row3 col1\" >0.994967</td>\n",
|
||
" <td id=\"T_b8c4d_row3_col2\" class=\"data row3 col2\" >0.996822</td>\n",
|
||
" <td id=\"T_b8c4d_row3_col3\" class=\"data row3 col3\" >0.996494</td>\n",
|
||
" <td id=\"T_b8c4d_row3_col4\" class=\"data row3 col4\" >0.995458</td>\n",
|
||
" <td id=\"T_b8c4d_row3_col5\" class=\"data row3 col5\" >0.996385</td>\n",
|
||
" <td id=\"T_b8c4d_row3_col6\" class=\"data row3 col6\" >0.994642</td>\n",
|
||
" <td id=\"T_b8c4d_row3_col7\" class=\"data row3 col7\" >0.995730</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_b8c4d_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
|
||
" <td id=\"T_b8c4d_row4_col0\" class=\"data row4 col0\" >0.993907</td>\n",
|
||
" <td id=\"T_b8c4d_row4_col1\" class=\"data row4 col1\" >0.996049</td>\n",
|
||
" <td id=\"T_b8c4d_row4_col2\" class=\"data row4 col2\" >0.992164</td>\n",
|
||
" <td id=\"T_b8c4d_row4_col3\" class=\"data row4 col3\" >0.994521</td>\n",
|
||
" <td id=\"T_b8c4d_row4_col4\" class=\"data row4 col4\" >0.994114</td>\n",
|
||
" <td id=\"T_b8c4d_row4_col5\" class=\"data row4 col5\" >0.996014</td>\n",
|
||
" <td id=\"T_b8c4d_row4_col6\" class=\"data row4 col6\" >0.993035</td>\n",
|
||
" <td id=\"T_b8c4d_row4_col7\" class=\"data row4 col7\" >0.995285</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_b8c4d_level0_row5\" class=\"row_heading level0 row5\" >logistic</th>\n",
|
||
" <td id=\"T_b8c4d_row5_col0\" class=\"data row5 col0\" >0.995093</td>\n",
|
||
" <td id=\"T_b8c4d_row5_col1\" class=\"data row5 col1\" >0.997574</td>\n",
|
||
" <td id=\"T_b8c4d_row5_col2\" class=\"data row5 col2\" >0.989041</td>\n",
|
||
" <td id=\"T_b8c4d_row5_col3\" class=\"data row5 col3\" >0.991234</td>\n",
|
||
" <td id=\"T_b8c4d_row5_col4\" class=\"data row5 col4\" >0.993303</td>\n",
|
||
" <td id=\"T_b8c4d_row5_col5\" class=\"data row5 col5\" >0.995273</td>\n",
|
||
" <td id=\"T_b8c4d_row5_col6\" class=\"data row5 col6\" >0.992058</td>\n",
|
||
" <td id=\"T_b8c4d_row5_col7\" class=\"data row5 col7\" >0.994394</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_b8c4d_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
|
||
" <td id=\"T_b8c4d_row6_col0\" class=\"data row6 col0\" >0.946797</td>\n",
|
||
" <td id=\"T_b8c4d_row6_col1\" class=\"data row6 col1\" >0.942845</td>\n",
|
||
" <td id=\"T_b8c4d_row6_col2\" class=\"data row6 col2\" >0.998521</td>\n",
|
||
" <td id=\"T_b8c4d_row6_col3\" class=\"data row6 col3\" >0.997808</td>\n",
|
||
" <td id=\"T_b8c4d_row6_col4\" class=\"data row6 col4\" >0.975645</td>\n",
|
||
" <td id=\"T_b8c4d_row6_col5\" class=\"data row6 col5\" >0.973492</td>\n",
|
||
" <td id=\"T_b8c4d_row6_col6\" class=\"data row6 col6\" >0.971971</td>\n",
|
||
" <td id=\"T_b8c4d_row6_col7\" class=\"data row6 col7\" >0.969549</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_b8c4d_level0_row7\" class=\"row_heading level0 row7\" >knn</th>\n",
|
||
" <td id=\"T_b8c4d_row7_col0\" class=\"data row7 col0\" >0.972527</td>\n",
|
||
" <td id=\"T_b8c4d_row7_col1\" class=\"data row7 col1\" >0.966497</td>\n",
|
||
" <td id=\"T_b8c4d_row7_col2\" class=\"data row7 col2\" >0.962082</td>\n",
|
||
" <td id=\"T_b8c4d_row7_col3\" class=\"data row7 col3\" >0.954635</td>\n",
|
||
" <td id=\"T_b8c4d_row7_col4\" class=\"data row7 col4\" >0.972471</td>\n",
|
||
" <td id=\"T_b8c4d_row7_col5\" class=\"data row7 col5\" >0.966818</td>\n",
|
||
" <td id=\"T_b8c4d_row7_col6\" class=\"data row7 col6\" >0.967276</td>\n",
|
||
" <td id=\"T_b8c4d_row7_col7\" class=\"data row7 col7\" >0.960529</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1659b03a1e0>"
|
||
]
|
||
},
|
||
"execution_count": 46,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
||
" [\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" \"Accuracy_train\",\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_train\",\n",
|
||
" \"F1_test\",\n",
|
||
" ]\n",
|
||
"]\n",
|
||
"#сортировка по столбцу Accuracy_test в порядке убывания\n",
|
||
"#чтобы увидеть модели с наивысшей точностью на тестовой выборке в начале таблицы\n",
|
||
"class_metrics.sort_values(\n",
|
||
" by=\"Accuracy_test\", ascending=False\n",
|
||
").style.background_gradient( #визуализация\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" #Указывает, какие столбцы будут окрашены\n",
|
||
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
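{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "A minimal sketch of how `pd.DataFrame.from_dict(..., \"index\")` turns the nested `class_models` structure into the summary table. It uses a hypothetical dictionary `toy_models` with made-up values. The outer keys (model names) become row labels and the inner keys become columns. Selecting only the metric columns drops non-metric entries such as `\"model\"`."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import pandas as pd\n",
  "\n",
  "# Hypothetical nested dict shaped like class_models; the values are made up\n",
  "toy_models = {\n",
  "    \"model_a\": {\"model\": None, \"Accuracy_test\": 0.95, \"F1_test\": 0.94},\n",
  "    \"model_b\": {\"model\": None, \"Accuracy_test\": 0.99, \"F1_test\": 0.98},\n",
  "}\n",
  "# orient=\"index\": outer keys become row labels, inner keys become columns\n",
  "toy_metrics = pd.DataFrame.from_dict(toy_models, orient=\"index\")[[\"Accuracy_test\", \"F1_test\"]]\n",
  "# Same ordering rule as the summary table: best test accuracy first\n",
  "ranked = toy_metrics.sort_values(by=\"Accuracy_test\", ascending=False)\n",
  "print(ranked.index.tolist())  # ['model_b', 'model_a']"
 ]
},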
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
|
||
]
|
||
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_76f05_row0_col0, #T_76f05_row0_col1, #T_76f05_row1_col0, #T_76f05_row1_col1, #T_76f05_row2_col0, #T_76f05_row2_col1, #T_76f05_row3_col0, #T_76f05_row3_col1, #T_76f05_row4_col0, #T_76f05_row4_col1, #T_76f05_row5_col0, #T_76f05_row5_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_76f05_row0_col2, #T_76f05_row0_col3, #T_76f05_row0_col4, #T_76f05_row1_col2, #T_76f05_row1_col3, #T_76f05_row1_col4, #T_76f05_row2_col2, #T_76f05_row2_col3, #T_76f05_row2_col4, #T_76f05_row3_col2, #T_76f05_row3_col3, #T_76f05_row3_col4, #T_76f05_row4_col2, #T_76f05_row4_col3, #T_76f05_row4_col4, #T_76f05_row5_col2, #T_76f05_row5_col3, #T_76f05_row5_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_76f05_row6_col0, #T_76f05_row6_col1 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_76f05_row6_col2 {\n",
" background-color: #d45270;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_76f05_row6_col3, #T_76f05_row6_col4 {\n",
" background-color: #d8576b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_76f05_row7_col0, #T_76f05_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_76f05_row7_col2, #T_76f05_row7_col3, #T_76f05_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_76f05\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_76f05_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_76f05_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_76f05_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_76f05_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_76f05_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_76f05_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_76f05_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_76f05_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_76f05_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_76f05_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_76f05_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_76f05_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_76f05_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_76f05_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_76f05_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
" <td id=\"T_76f05_row6_col0\" class=\"data row6 col0\" >0.999629</td>\n",
" <td id=\"T_76f05_row6_col1\" class=\"data row6 col1\" >0.999562</td>\n",
" <td id=\"T_76f05_row6_col2\" class=\"data row6 col2\" >0.999754</td>\n",
" <td id=\"T_76f05_row6_col3\" class=\"data row6 col3\" >0.999240</td>\n",
" <td id=\"T_76f05_row6_col4\" class=\"data row6 col4\" >0.999240</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_76f05_level0_row7\" class=\"row_heading level0 row7\" >knn</th>\n",
" <td id=\"T_76f05_row7_col0\" class=\"data row7 col0\" >0.980536</td>\n",
" <td id=\"T_76f05_row7_col1\" class=\"data row7 col1\" >0.976933</td>\n",
" <td id=\"T_76f05_row7_col2\" class=\"data row7 col2\" >0.995960</td>\n",
" <td id=\"T_76f05_row7_col3\" class=\"data row7 col3\" >0.960098</td>\n",
" <td id=\"T_76f05_row7_col4\" class=\"data row7 col4\" >0.960107</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1659f0e94f0>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
"    [\n",
"        \"Accuracy_test\",\n",
"        \"F1_test\",\n",
"        \"ROC_AUC_test\",\n",
"        \"Cohen_kappa_test\",\n",
"        \"MCC_test\",\n",
"    ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
"    cmap=\"plasma\",\n",
"    low=0.3,\n",
"    high=1,\n",
"    subset=[\n",
"        \"ROC_AUC_test\",\n",
"        \"MCC_test\",\n",
"        \"Cohen_kappa_test\",\n",
"    ],\n",
").background_gradient(\n",
"    cmap=\"viridis\",\n",
"    low=1,\n",
"    high=0.3,\n",
"    subset=[\n",
"        \"Accuracy_test\",\n",
"        \"F1_test\",\n",
"    ],\n",
")"
]
},
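{
"cell_type": "markdown",
"metadata": {},
"source": [
"Six of the eight models score 1.0 on all three metrics above, so the table says little about how these metrics differ from accuracy. The toy example below uses hypothetical labels (not the diamonds data) to show why Cohen's kappa and MCC are stricter. A model that always predicts the majority class reaches 80% accuracy on an imbalanced 8:2 sample, yet both chance-corrected metrics are 0. ROC AUC is left out because it is computed from predicted probabilities rather than hard labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import metrics\n",
"\n",
"# Hypothetical labels: 8 negatives and 2 positives (imbalanced classes)\n",
"y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]\n",
"# A degenerate \"model\" that always predicts the majority class\n",
"y_major = [0] * len(y_true)\n",
"\n",
"print(\"Accuracy:\", metrics.accuracy_score(y_true, y_major))  # 0.8\n",
"print(\"Cohen's kappa:\", metrics.cohen_kappa_score(y_true, y_major))  # 0.0\n",
"print(\"MCC:\", metrics.matthews_corrcoef(y_true, y_major))  # 0.0"
]
},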
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Select and display the best model: the one with the highest MCC on the test set.\n",
"# kind=\"stable\" keeps the original model order among tied scores\n",
"best_model = str(\n",
"    class_metrics.sort_values(by=\"MCC_test\", ascending=False, kind=\"stable\").iloc[0].name\n",
")\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Displaying misclassified test samples for error analysis"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>Predicted</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [carat, Predicted, cut, color, clarity, depth, table, price, x, y, z, above_average_carat]\n",
"Index: []"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preprocess the test set with the fitted preprocessing pipeline;\n",
"# keep the X_test index so that rows can later be looked up by id\n",
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
"    preprocessing_result,\n",
"    index=X_test.index,\n",
"    columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"# Test-set predictions of the best model\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"# Indices of misclassified samples: compare the true labels in y_test\n",
"# (column above_average_carat) with the predictions y_pred;\n",
"# the resulting boolean mask is True where the prediction is wrong\n",
"error_index = y_test[y_test[\"above_average_carat\"] != y_pred].index.tolist()\n",
"# Number of errors\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
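{
"cell_type": "markdown",
"metadata": {},
"source": [
"The error table above is empty, which matches the perfect scores. Its column list, however, shows that `X_test` contains the target column `above_average_carat`, and `carat`, from which the target is presumably derived, is an input as well. Perfect test metrics are a typical symptom of such target leakage. Below is a minimal sketch of a hypothetical helper, `find_leaky_columns`, run on made-up data. It flags features that share the target's name or whose absolute correlation with the target reaches an assumed threshold of 0.99. On the notebook's data it could be called as `find_leaky_columns(X_train, y_train[\"above_average_carat\"])`. This is only a first screen: the correlation check does not flag `carat` in the demo, even though a threshold on `carat` reproduces the demo target exactly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def find_leaky_columns(X, y, threshold=0.99):\n",
"    \"\"\"Hypothetical helper: list the features of X that may leak the target y.\"\"\"\n",
"    # A feature named like the target is the most obvious leak\n",
"    suspects = [col for col in X.columns if col == y.name]\n",
"    # Numeric features (almost) perfectly correlated with the target are also suspicious\n",
"    corr = X.select_dtypes(\"number\").corrwith(y.astype(float)).abs()\n",
"    suspects += [col for col in corr.index if corr[col] >= threshold and col not in suspects]\n",
"    return suspects\n",
"\n",
"\n",
"# Made-up rows that only mimic the notebook's column names\n",
"X_demo = pd.DataFrame(\n",
"    {\n",
"        \"carat\": [0.3, 0.5, 0.9, 1.2],\n",
"        \"depth\": [61.0, 62.5, 61.9, 60.8],\n",
"        \"above_average_carat\": [0, 0, 1, 1],\n",
"    }\n",
")\n",
"y_demo = pd.Series([0, 0, 1, 1], name=\"above_average_carat\")\n",
"\n",
"print(find_leaky_columns(X_demo, y_demo))  # ['above_average_carat']"
]
},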
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Example: using the trained model (pipeline) to make a prediction"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>carat</th>\n",
" <th>cut</th>\n",
" <th>color</th>\n",
" <th>clarity</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>price</th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>4500</th>\n",
" <td>0.9</td>\n",
" <td>Premium</td>\n",
" <td>H</td>\n",
" <td>SI1</td>\n",
" <td>61.9</td>\n",
" <td>58.0</td>\n",
" <td>3629</td>\n",
" <td>6.2</td>\n",
" <td>6.15</td>\n",
" <td>3.82</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" carat cut color clarity depth table price x y z \\\n",
"4500 0.9 Premium H SI1 61.9 58.0 3629 6.2 6.15 3.82 \n",
"\n",
" above_average_carat \n",
"4500 1 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>Length_to_Width_Ratio</th>\n",
" <th>carat</th>\n",
" <th>depth</th>\n",
" <th>table</th>\n",
" <th>z</th>\n",
" <th>above_average_carat</th>\n",
" <th>cut_Good</th>\n",
" <th>cut_Ideal</th>\n",
" <th>...</th>\n",
" <th>color_I</th>\n",
" <th>color_J</th>\n",
" <th>clarity_IF</th>\n",
" <th>clarity_SI1</th>\n",
" <th>clarity_SI2</th>\n",
" <th>clarity_VS1</th>\n",
" <th>clarity_VS2</th>\n",
" <th>clarity_VVS1</th>\n",
" <th>clarity_VVS2</th>\n",
" <th>price</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>4500</th>\n",
" <td>0.420272</td>\n",
" <td>0.363352</td>\n",
" <td>1.156653</td>\n",
" <td>0.217442</td>\n",
" <td>0.10618</td>\n",
" <td>0.245753</td>\n",
" <td>0.399417</td>\n",
" <td>1.168162</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3629.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1 rows × 26 columns</p>\n",
"</div>"
],
"text/plain": [
" x y Length_to_Width_Ratio carat depth table \\\n",
"4500 0.420272 0.363352 1.156653 0.217442 0.10618 0.245753 \n",
"\n",
" z above_average_carat cut_Good cut_Ideal ... color_I \\\n",
"4500 0.399417 1.168162 0.0 0.0 ... 0.0 \n",
"\n",
" color_J clarity_IF clarity_SI1 clarity_SI2 clarity_VS1 clarity_VS2 \\\n",
"4500 0.0 0.0 1.0 0.0 0.0 0.0 \n",
"\n",
" clarity_VVS1 clarity_VVS2 price \n",
"4500 0.0 0.0 3629.0 \n",
"\n",
"[1 rows x 26 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 1 (proba: [4.39873930e-04 9.99560126e-01])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 1'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 4500\n",
"# Double brackets return a one-row DataFrame and keep the column dtypes\n",
"test = X_test.loc[[example_id]]\n",
"test_preprocessed = preprocessed_df.loc[[example_id]]\n",
"display(test)\n",
"display(test_preprocessed)\n",
"# Predicted class probabilities\n",
"result_proba = model.predict_proba(test)[0]\n",
"# Predicted class\n",
"result = model.predict(test)[0]\n",
"# Actual class\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
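{
"cell_type": "markdown",
"metadata": {},
"source": [
"The printed probabilities have one entry per class, in the order of the fitted model's `classes_` attribute, so `[4.4e-04, 9.996e-01]` puts almost all of the probability on class 1. The sketch below reproduces this with a toy logistic regression on made-up one-feature data (not the notebook's pipeline). It pairs each probability with its label and prints the `predict` result next to the most probable class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# Made-up single-feature data: small stones (0) vs. large stones (1)\n",
"X_toy = np.array([[0.3], [0.5], [0.9], [1.2]])\n",
"y_toy = np.array([0, 0, 1, 1])\n",
"clf = LogisticRegression().fit(X_toy, y_toy)\n",
"\n",
"sample = np.array([[1.0]])\n",
"proba = clf.predict_proba(sample)[0]\n",
"# The columns of predict_proba follow the order of clf.classes_\n",
"for label, p in zip(clf.classes_, proba):\n",
"    print(f\"class {label}: probability {p:.4f}\")\n",
"\n",
"# The predicted label is the class with the highest probability\n",
"print(\"predict:\", clf.predict(sample)[0], \"argmax:\", clf.classes_[np.argmax(proba)])"
]
},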
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Hyperparameter tuning with grid search"
]
},
{
"cell_type": "code",
"execution_count": 209,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 2,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 20}"
]
},
"execution_count": 209,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# Tune the random forest model\n",
"optimized_model_type = \"random_forest\"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"# Hyperparameter grid; the \"model__\" prefix targets the \"model\" step of the pipeline\n",
"param_grid = {\n",
"    # Number of trees in the forest\n",
"    \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"    # Number of features considered when looking for the best split\n",
"    \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
"    # Maximum depth of a tree\n",
"    \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
"    # Function used to measure the quality of a split\n",
"    \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"}\n",
"# estimator: the pipeline to optimize; param_grid: the grid to search;\n",
"# n_jobs=-1: use all available processors in parallel\n",
"gs_optimizer = GridSearchCV(\n",
"    estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"# fit() runs the search, training the pipeline for every combination of\n",
"# hyperparameters (with 5-fold cross-validation by default);\n",
"# ravel() flattens the labels into a one-dimensional array\n",
"gs_optimizer.fit(X_train, y_train.values.ravel())\n",
"# Best hyperparameter values found\n",
"gs_optimizer.best_params_"
]
},
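{
"cell_type": "markdown",
"metadata": {},
"source": [
"The keys of `best_params_` carry the `model__` prefix of the pipeline step they belong to. Rather than copying the values into a new estimator by hand, the prefix can be stripped and the dictionary unpacked, as sketched below. The dictionary is copied from the output above, and `random_state=42` is an assumed value. scikit-learn also offers two alternatives. `Pipeline.set_params(**best_params)` accepts the prefixed keys directly. With the default `refit=True`, the fitted search object already holds the best pipeline, refit on the whole training set, as `best_estimator_`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"# Copied from the grid-search output above\n",
"best_params = {\n",
"    \"model__criterion\": \"gini\",\n",
"    \"model__max_depth\": 2,\n",
"    \"model__max_features\": \"sqrt\",\n",
"    \"model__n_estimators\": 20,\n",
"}\n",
"\n",
"prefix = \"model__\"\n",
"# Keep only the parameters of the \"model\" step and drop the prefix\n",
"model_params = {\n",
"    key[len(prefix):]: value for key, value in best_params.items() if key.startswith(prefix)\n",
"}\n",
"print(model_params)  # {'criterion': 'gini', 'max_depth': 2, 'max_features': 'sqrt', 'n_estimators': 20}\n",
"\n",
"# random_state=42 is an assumed value, not the notebook's random_state\n",
"tuned_forest = RandomForestClassifier(random_state=42, **model_params)\n",
"print(tuned_forest)"
]
},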
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training the model with the hyperparameters found by the grid search"
]
},
{
"cell_type": "code",
"execution_count": 210,
"metadata": {},
"outputs": [],
"source": [
"# Random forest with the best hyperparameters reported by the grid search above\n",
"optimized_model = ensemble.RandomForestClassifier(\n",
"    random_state=random_state,\n",
"    criterion=\"gini\",\n",
"    max_depth=2,\n",
"    max_features=\"sqrt\",\n",
"    n_estimators=20,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline(\n",
"    [(\"pipeline\", pipeline_end), (\"model\", optimized_model)]\n",
").fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"# Same metrics as for the original models, so that the old and new versions can be compared\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Collecting the metrics of the old and new versions of the model into one table"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"# Row \"Old\": the original random forest; row \"New\": the retrained one\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(data=class_models[optimized_model_type])\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(data=result)\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Comparing the metrics of the old and new models"
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_703b0_row0_col0, #T_703b0_row0_col1, #T_703b0_row0_col2, #T_703b0_row0_col3, #T_703b0_row1_col0, #T_703b0_row1_col1, #T_703b0_row1_col2, #T_703b0_row1_col3 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_703b0_row0_col4, #T_703b0_row0_col5, #T_703b0_row0_col6, #T_703b0_row0_col7, #T_703b0_row1_col4, #T_703b0_row1_col5, #T_703b0_row1_col6, #T_703b0_row1_col7 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_703b0\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_703b0_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_703b0_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_703b0_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_703b0_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_703b0_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_703b0_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_703b0_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_703b0_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" > </th>\n",
" <th class=\"blank col1\" > </th>\n",
" <th class=\"blank col2\" > </th>\n",
" <th class=\"blank col3\" > </th>\n",
" <th class=\"blank col4\" > </th>\n",
" <th class=\"blank col5\" > </th>\n",
" <th class=\"blank col6\" > </th>\n",
" <th class=\"blank col7\" > </th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_703b0_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_703b0_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_703b0_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_703b0_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_703b0_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_703b0_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x21f4cb4fe90>"
]
},
"execution_count": 212,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
"    [\n",
"        \"Precision_train\",\n",
"        \"Precision_test\",\n",
"        \"Recall_train\",\n",
"        \"Recall_test\",\n",
"        \"Accuracy_train\",\n",
"        \"Accuracy_test\",\n",
"        \"F1_train\",\n",
"        \"F1_test\",\n",
"    ]\n",
"].style.background_gradient(\n",
"    cmap=\"plasma\",\n",
"    low=0.3,\n",
"    high=1,\n",
"    subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
"    cmap=\"viridis\",\n",
"    low=1,\n",
"    high=0.3,\n",
"    subset=[\n",
"        \"Precision_train\",\n",
"        \"Precision_test\",\n",
"        \"Recall_train\",\n",
"        \"Recall_test\",\n",
"    ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_0ba78_row0_col0, #T_0ba78_row0_col1, #T_0ba78_row1_col0, #T_0ba78_row1_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0ba78_row0_col2, #T_0ba78_row0_col3, #T_0ba78_row0_col4, #T_0ba78_row1_col2, #T_0ba78_row1_col3, #T_0ba78_row1_col4 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_0ba78\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_0ba78_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_0ba78_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_0ba78_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_0ba78_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_0ba78_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" > </th>\n",
" <th class=\"blank col1\" > </th>\n",
" <th class=\"blank col2\" > </th>\n",
" <th class=\"blank col3\" > </th>\n",
" <th class=\"blank col4\" > </th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_0ba78_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_0ba78_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_0ba78_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0ba78_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_0ba78_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_0ba78_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x21f4cb4ef30>"
]
},
"execution_count": 213,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
"    [\n",
"        \"Accuracy_test\",\n",
"        \"F1_test\",\n",
"        \"ROC_AUC_test\",\n",
"        \"Cohen_kappa_test\",\n",
"        \"MCC_test\",\n",
"    ]\n",
"].style.background_gradient(\n",
"    cmap=\"plasma\",\n",
"    low=0.3,\n",
"    high=1,\n",
"    subset=[\n",
"        \"ROC_AUC_test\",\n",
"        \"MCC_test\",\n",
"        \"Cohen_kappa_test\",\n",
"    ],\n",
").background_gradient(\n",
"    cmap=\"viridis\",\n",
"    low=1,\n",
"    high=0.3,\n",
"    subset=[\n",
"        \"Accuracy_test\",\n",
"        \"F1_test\",\n",
"    ],\n",
")"
]
},
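{
"cell_type": "markdown",
"metadata": {},
"source": [
"The old and new models score identically in both tables. For binary labels, each stored `Confusion_matrix` follows scikit-learn's layout `[[TN, FP], [FN, TP]]`: rows are true classes and columns are predicted classes. The four counts can therefore be unpacked with `ravel()`, as the sketch below shows on made-up labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# Made-up true and predicted labels (0 = Less, 1 = More)\n",
"y_true = np.array([0, 0, 1, 1, 1, 0])\n",
"y_pred_toy = np.array([0, 1, 1, 1, 0, 0])\n",
"\n",
"# Rows are true classes, columns are predicted classes: [[TN, FP], [FN, TP]]\n",
"tn, fp, fn, tp = confusion_matrix(y_true, y_pred_toy, labels=[0, 1]).ravel()\n",
"print(f\"TN={tn} FP={fp} FN={fn} TP={tp}\")  # TN=2 FP=1 FN=1 TP=2"
]
},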
{
"cell_type": "code",
"execution_count": 215,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
"\n",
"# One confusion matrix per row of optimized_metrics (Old, New);\n",
"# labels: carat below (Less) or above (More) the average\n",
"for index in range(len(optimized_metrics)):\n",
"    c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
"    ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"    ax.flat[index].set_title(optimized_metrics.index[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}