4368 lines
295 KiB
Plaintext
4368 lines
295 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Классификация"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Загрузка набора данных"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Survived</th>\n",
|
|||
|
" <th>Pclass</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>SibSp</th>\n",
|
|||
|
" <th>Parch</th>\n",
|
|||
|
" <th>Ticket</th>\n",
|
|||
|
" <th>Fare</th>\n",
|
|||
|
" <th>Cabin</th>\n",
|
|||
|
" <th>Embarked</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>PassengerId</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Braund, Mr. Owen Harris</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>22.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>A/5 21171</td>\n",
|
|||
|
" <td>7.2500</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>38.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>PC 17599</td>\n",
|
|||
|
" <td>71.2833</td>\n",
|
|||
|
" <td>C85</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Heikkinen, Miss. Laina</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>26.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>STON/O2. 3101282</td>\n",
|
|||
|
" <td>7.9250</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>35.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>113803</td>\n",
|
|||
|
" <td>53.1000</td>\n",
|
|||
|
" <td>C123</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>5</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Allen, Mr. William Henry</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>35.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>373450</td>\n",
|
|||
|
" <td>8.0500</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>887</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Montvila, Rev. Juozas</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>27.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>211536</td>\n",
|
|||
|
" <td>13.0000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>888</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Graham, Miss. Margaret Edith</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>19.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>112053</td>\n",
|
|||
|
" <td>30.0000</td>\n",
|
|||
|
" <td>B42</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>889</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Johnston, Miss. Catherine Helen \"Carrie\"</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>W./C. 6607</td>\n",
|
|||
|
" <td>23.4500</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>890</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Behr, Mr. Karl Howell</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>26.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>111369</td>\n",
|
|||
|
" <td>30.0000</td>\n",
|
|||
|
" <td>C148</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>891</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Dooley, Mr. Patrick</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>32.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>370376</td>\n",
|
|||
|
" <td>7.7500</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>Q</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>891 rows × 11 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Survived Pclass \\\n",
|
|||
|
"PassengerId \n",
|
|||
|
"1 0 3 \n",
|
|||
|
"2 1 1 \n",
|
|||
|
"3 1 3 \n",
|
|||
|
"4 1 1 \n",
|
|||
|
"5 0 3 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"887 0 2 \n",
|
|||
|
"888 1 1 \n",
|
|||
|
"889 0 3 \n",
|
|||
|
"890 1 1 \n",
|
|||
|
"891 0 3 \n",
|
|||
|
"\n",
|
|||
|
" Name Sex Age \\\n",
|
|||
|
"PassengerId \n",
|
|||
|
"1 Braund, Mr. Owen Harris male 22.0 \n",
|
|||
|
"2 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 \n",
|
|||
|
"3 Heikkinen, Miss. Laina female 26.0 \n",
|
|||
|
"4 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 \n",
|
|||
|
"5 Allen, Mr. William Henry male 35.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"887 Montvila, Rev. Juozas male 27.0 \n",
|
|||
|
"888 Graham, Miss. Margaret Edith female 19.0 \n",
|
|||
|
"889 Johnston, Miss. Catherine Helen \"Carrie\" female NaN \n",
|
|||
|
"890 Behr, Mr. Karl Howell male 26.0 \n",
|
|||
|
"891 Dooley, Mr. Patrick male 32.0 \n",
|
|||
|
"\n",
|
|||
|
" SibSp Parch Ticket Fare Cabin Embarked \n",
|
|||
|
"PassengerId \n",
|
|||
|
"1 1 0 A/5 21171 7.2500 NaN S \n",
|
|||
|
"2 1 0 PC 17599 71.2833 C85 C \n",
|
|||
|
"3 0 0 STON/O2. 3101282 7.9250 NaN S \n",
|
|||
|
"4 1 0 113803 53.1000 C123 S \n",
|
|||
|
"5 0 0 373450 8.0500 NaN S \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"887 0 0 211536 13.0000 NaN S \n",
|
|||
|
"888 0 0 112053 30.0000 B42 S \n",
|
|||
|
"889 1 2 W./C. 6607 23.4500 NaN S \n",
|
|||
|
"890 0 0 111369 30.0000 C148 C \n",
|
|||
|
"891 0 0 370376 7.7500 NaN Q \n",
|
|||
|
"\n",
|
|||
|
"[891 rows x 11 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"\n",
|
|||
|
"# Make scikit-learn transformers return pandas DataFrames instead of NumPy arrays\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"\n",
|
|||
|
"# Seed passed to randomized steps (e.g. the train/test split below)\n",
|
|||
|
"random_state = 9\n",
|
|||
|
"\n",
|
|||
|
"# PassengerId becomes the index, so it is not treated as a feature column\n",
|
|||
|
"df = pd.read_csv(\"data/titanic.csv\", index_col=\"PassengerId\")\n",
|
|||
|
"df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Разделение набора данных на обучающую и тестовую выборки (80/20) для задачи классификации\n",
|
|||
|
"\n",
|
|||
|
"Целевой признак -- Survived"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Survived</th>\n",
|
|||
|
" <th>Pclass</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>SibSp</th>\n",
|
|||
|
" <th>Parch</th>\n",
|
|||
|
" <th>Ticket</th>\n",
|
|||
|
" <th>Fare</th>\n",
|
|||
|
" <th>Cabin</th>\n",
|
|||
|
" <th>Embarked</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>PassengerId</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>145</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Andrew, Mr. Edgardo Samuel</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>18.00</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>231945</td>\n",
|
|||
|
" <td>11.5000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>206</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Strom, Miss. Telma Matilda</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>2.00</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>347054</td>\n",
|
|||
|
" <td>10.4625</td>\n",
|
|||
|
" <td>G6</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>349</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Coutts, Master. William Loch \"William\"</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>3.00</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>C.A. 37671</td>\n",
|
|||
|
" <td>15.9000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>329</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Goldsmith, Mrs. Frank John (Emily Alice Brown)</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>31.00</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>363291</td>\n",
|
|||
|
" <td>20.5250</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>289</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Hosono, Mr. Masabumi</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>42.00</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>237798</td>\n",
|
|||
|
" <td>13.0000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>756</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Hamalainen, Master. Viljo</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>0.67</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>250649</td>\n",
|
|||
|
" <td>14.5000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>816</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Fry, Mr. Richard</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>112058</td>\n",
|
|||
|
" <td>0.0000</td>\n",
|
|||
|
" <td>B102</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>890</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Behr, Mr. Karl Howell</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>26.00</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>111369</td>\n",
|
|||
|
" <td>30.0000</td>\n",
|
|||
|
" <td>C148</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>738</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Lesurer, Mr. Gustave J</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>35.00</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>PC 17755</td>\n",
|
|||
|
" <td>512.3292</td>\n",
|
|||
|
" <td>B101</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>61</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Sirayanian, Mr. Orsen</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>22.00</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2669</td>\n",
|
|||
|
" <td>7.2292</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>712 rows × 11 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Survived Pclass Name \\\n",
|
|||
|
"PassengerId \n",
|
|||
|
"145 0 2 Andrew, Mr. Edgardo Samuel \n",
|
|||
|
"206 0 3 Strom, Miss. Telma Matilda \n",
|
|||
|
"349 1 3 Coutts, Master. William Loch \"William\" \n",
|
|||
|
"329 1 3 Goldsmith, Mrs. Frank John (Emily Alice Brown) \n",
|
|||
|
"289 1 2 Hosono, Mr. Masabumi \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"756 1 2 Hamalainen, Master. Viljo \n",
|
|||
|
"816 0 1 Fry, Mr. Richard \n",
|
|||
|
"890 1 1 Behr, Mr. Karl Howell \n",
|
|||
|
"738 1 1 Lesurer, Mr. Gustave J \n",
|
|||
|
"61 0 3 Sirayanian, Mr. Orsen \n",
|
|||
|
"\n",
|
|||
|
" Sex Age SibSp Parch Ticket Fare Cabin Embarked \n",
|
|||
|
"PassengerId \n",
|
|||
|
"145 male 18.00 0 0 231945 11.5000 NaN S \n",
|
|||
|
"206 female 2.00 0 1 347054 10.4625 G6 S \n",
|
|||
|
"349 male 3.00 1 1 C.A. 37671 15.9000 NaN S \n",
|
|||
|
"329 female 31.00 1 1 363291 20.5250 NaN S \n",
|
|||
|
"289 male 42.00 0 0 237798 13.0000 NaN S \n",
|
|||
|
"... ... ... ... ... ... ... ... ... \n",
|
|||
|
"756 male 0.67 1 1 250649 14.5000 NaN S \n",
|
|||
|
"816 male NaN 0 0 112058 0.0000 B102 S \n",
|
|||
|
"890 male 26.00 0 0 111369 30.0000 C148 C \n",
|
|||
|
"738 male 35.00 0 0 PC 17755 512.3292 B101 C \n",
|
|||
|
"61 male 22.00 0 0 2669 7.2292 NaN C \n",
|
|||
|
"\n",
|
|||
|
"[712 rows x 11 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Survived</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>PassengerId</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>145</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>206</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>349</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>329</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>289</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>756</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>816</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>890</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>738</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>61</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>712 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Survived\n",
|
|||
|
"PassengerId \n",
|
|||
|
"145 0\n",
|
|||
|
"206 0\n",
|
|||
|
"349 1\n",
|
|||
|
"329 1\n",
|
|||
|
"289 1\n",
|
|||
|
"... ...\n",
|
|||
|
"756 1\n",
|
|||
|
"816 0\n",
|
|||
|
"890 1\n",
|
|||
|
"738 1\n",
|
|||
|
"61 0\n",
|
|||
|
"\n",
|
|||
|
"[712 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Survived</th>\n",
|
|||
|
" <th>Pclass</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>SibSp</th>\n",
|
|||
|
" <th>Parch</th>\n",
|
|||
|
" <th>Ticket</th>\n",
|
|||
|
" <th>Fare</th>\n",
|
|||
|
" <th>Cabin</th>\n",
|
|||
|
" <th>Embarked</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>PassengerId</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>843</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Serepeca, Miss. Augusta</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>30.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>113798</td>\n",
|
|||
|
" <td>31.0000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>791</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Keane, Mr. Andrew \"Andy\"</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>12460</td>\n",
|
|||
|
" <td>7.7500</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>Q</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>509</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Olsen, Mr. Henry Margido</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>28.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>C 4001</td>\n",
|
|||
|
" <td>22.5250</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>828</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Mallet, Master. Andre</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>S.C./PARIS 2079</td>\n",
|
|||
|
" <td>37.0042</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>414</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Cunningham, Mr. Alfred Fleming</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>239853</td>\n",
|
|||
|
" <td>0.0000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>824</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Moor, Mrs. (Beila)</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>27.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>392096</td>\n",
|
|||
|
" <td>12.4750</td>\n",
|
|||
|
" <td>E121</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>353</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Elias, Mr. Tannous</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>15.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2695</td>\n",
|
|||
|
" <td>7.2292</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>674</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Wilhelms, Mr. Charles</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>31.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>244270</td>\n",
|
|||
|
" <td>13.0000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>100</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Kantor, Mr. Sinai</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>34.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>244367</td>\n",
|
|||
|
" <td>26.0000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>542</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Andersson, Miss. Ingeborg Constanzia</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>9.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>347082</td>\n",
|
|||
|
" <td>31.2750</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>179 rows × 11 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Survived Pclass Name Sex \\\n",
|
|||
|
"PassengerId \n",
|
|||
|
"843 1 1 Serepeca, Miss. Augusta female \n",
|
|||
|
"791 0 3 Keane, Mr. Andrew \"Andy\" male \n",
|
|||
|
"509 0 3 Olsen, Mr. Henry Margido male \n",
|
|||
|
"828 1 2 Mallet, Master. Andre male \n",
|
|||
|
"414 0 2 Cunningham, Mr. Alfred Fleming male \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"824 1 3 Moor, Mrs. (Beila) female \n",
|
|||
|
"353 0 3 Elias, Mr. Tannous male \n",
|
|||
|
"674 1 2 Wilhelms, Mr. Charles male \n",
|
|||
|
"100 0 2 Kantor, Mr. Sinai male \n",
|
|||
|
"542 0 3 Andersson, Miss. Ingeborg Constanzia female \n",
|
|||
|
"\n",
|
|||
|
" Age SibSp Parch Ticket Fare Cabin Embarked \n",
|
|||
|
"PassengerId \n",
|
|||
|
"843 30.0 0 0 113798 31.0000 NaN C \n",
|
|||
|
"791 NaN 0 0 12460 7.7500 NaN Q \n",
|
|||
|
"509 28.0 0 0 C 4001 22.5250 NaN S \n",
|
|||
|
"828 1.0 0 2 S.C./PARIS 2079 37.0042 NaN C \n",
|
|||
|
"414 NaN 0 0 239853 0.0000 NaN S \n",
|
|||
|
"... ... ... ... ... ... ... ... \n",
|
|||
|
"824 27.0 0 1 392096 12.4750 E121 S \n",
|
|||
|
"353 15.0 1 1 2695 7.2292 NaN C \n",
|
|||
|
"674 31.0 0 0 244270 13.0000 NaN S \n",
|
|||
|
"100 34.0 1 0 244367 26.0000 NaN S \n",
|
|||
|
"542 9.0 4 2 347082 31.2750 NaN S \n",
|
|||
|
"\n",
|
|||
|
"[179 rows x 11 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Survived</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>PassengerId</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>843</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>791</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>509</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>828</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>414</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>824</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>353</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>674</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>100</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>542</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>179 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Survived\n",
|
|||
|
"PassengerId \n",
|
|||
|
"843 1\n",
|
|||
|
"791 0\n",
|
|||
|
"509 0\n",
|
|||
|
"828 1\n",
|
|||
|
"414 0\n",
|
|||
|
"... ...\n",
|
|||
|
"824 1\n",
|
|||
|
"353 0\n",
|
|||
|
"674 1\n",
|
|||
|
"100 0\n",
|
|||
|
"542 0\n",
|
|||
|
"\n",
|
|||
|
"[179 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from src.utils import split_stratified_into_train_val_test\n",
|
|||
|
"\n",
|
|||
|
"# Stratified 80/20 split on the target column; frac_val=0 requests no validation part\n",
|
|||
|
"# (X_val / y_val are presumably empty — see src.utils).\n",
|
|||
|
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
|
|||
|
"    df,\n",
|
|||
|
"    stratify_colname=\"Survived\",\n",
|
|||
|
"    frac_train=0.80,\n",
|
|||
|
"    frac_val=0,\n",
|
|||
|
"    frac_test=0.20,\n",
|
|||
|
"    random_state=random_state,\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# display() publishes each argument as its own output, in order\n",
|
|||
|
"display(\"X_train\", X_train, \"y_train\", y_train)\n",
|
|||
|
"display(\"X_test\", X_test, \"y_test\", y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование конвейера для классификации данных\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных значений и унитарное кодирование\n",
|
|||
|
"\n",
|
|||
|
"features_preprocessing -- трансформер для предобработки признаков\n",
|
|||
|
"\n",
|
|||
|
"features_engineering -- трансформер для конструирования признаков\n",
|
|||
|
"\n",
|
|||
|
"drop_columns -- трансформер для удаления колонок\n",
|
|||
|
"\n",
|
|||
|
"features_postprocessing -- трансформер для унитарного кодирования новых признаков\n",
|
|||
|
"\n",
|
|||
|
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков\n",
|
|||
|
"\n",
|
|||
|
"Конвейер выполняется последовательно.\n",
|
|||
|
"\n",
|
|||
|
"Трансформер (ColumnTransformer) применяет свои преобразования параллельно, каждое — к указанному для него набору колонок.\n",
|
|||
|
"\n",
|
|||
|
"Документация: \n",
|
|||
|
"\n",
|
|||
|
"https://scikit-learn.org/1.5/api/sklearn.pipeline.html\n",
|
|||
|
"\n",
|
|||
|
"https://scikit-learn.org/1.5/modules/generated/sklearn.compose.ColumnTransformer.html#sklearn.compose.ColumnTransformer"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"# StandardScaler's public home is sklearn.preprocessing (not sklearn.discriminant_analysis)\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
|
|||
|
"\n",
|
|||
|
"from src.transformers import TitanicFeatures\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# Columns removed by the drop_columns step; they are also excluded from the\n",
|
|||
|
"# generic numeric/categorical preprocessing below.\n",
|
|||
|
"columns_to_drop = [\"Survived\", \"Name\", \"Cabin\", \"Ticket\", \"Embarked\", \"Parch\", \"Fare\"]\n",
|
|||
|
"num_columns = [\n",
|
|||
|
"    column\n",
|
|||
|
"    for column in df.columns\n",
|
|||
|
"    if column not in columns_to_drop and df[column].dtype != \"object\"\n",
|
|||
|
"]\n",
|
|||
|
"cat_columns = [\n",
|
|||
|
"    column\n",
|
|||
|
"    for column in df.columns\n",
|
|||
|
"    if column not in columns_to_drop and df[column].dtype == \"object\"\n",
|
|||
|
"]\n",
|
|||
|
"\n",
|
|||
|
"# Numeric columns: median imputation, then standardization\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
"    [\n",
|
|||
|
"        (\"imputer\", num_imputer),\n",
|
|||
|
"        (\"scaler\", num_scaler),\n",
|
|||
|
"    ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Categorical columns: fill gaps with the constant 'unknown', then one-hot encode.\n",
|
|||
|
"# drop='first' removes one dummy per feature; handle_unknown='ignore' keeps unseen\n",
|
|||
|
"# categories at transform time from raising an error.\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
"    [\n",
|
|||
|
"        (\"imputer\", cat_imputer),\n",
|
|||
|
"        (\"encoder\", cat_encoder),\n",
|
|||
|
"    ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Name and Cabin only get missing-value imputation here; they are fed to\n",
|
|||
|
"# TitanicFeatures in the next step. Step names are kept as-is because they\n",
|
|||
|
"# form get_params() keys.\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
"    verbose_feature_names_out=False,\n",
|
|||
|
"    transformers=[\n",
|
|||
|
"        (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
"        (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
"        (\"prepocessing_features\", cat_imputer, [\"Name\", \"Cabin\"]),\n",
|
|||
|
"    ],\n",
|
|||
|
"    remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"features_engineering = ColumnTransformer(\n",
|
|||
|
"    verbose_feature_names_out=False,\n",
|
|||
|
"    transformers=[\n",
|
|||
|
"        (\"add_features\", TitanicFeatures(), [\"Name\", \"Cabin\"]),\n",
|
|||
|
"    ],\n",
|
|||
|
"    remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
"    verbose_feature_names_out=False,\n",
|
|||
|
"    transformers=[\n",
|
|||
|
"        (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
"    ],\n",
|
|||
|
"    remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# One-hot encode Cabin_type (presumably a column produced by TitanicFeatures)\n",
|
|||
|
"features_postprocessing = ColumnTransformer(\n",
|
|||
|
"    verbose_feature_names_out=False,\n",
|
|||
|
"    transformers=[\n",
|
|||
|
"        (\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n",
|
|||
|
"    ],\n",
|
|||
|
"    remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Steps run sequentially; each ColumnTransformer applies its transformers to their own column subsets\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
"    [\n",
|
|||
|
"        (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
"        (\"features_engineering\", features_engineering),\n",
|
|||
|
"        (\"drop_columns\", drop_columns),\n",
|
|||
|
"        (\"features_postprocessing\", features_postprocessing),\n",
|
|||
|
"    ]\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Демонстрация работы конвейера для предобработки данных при классификации"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Cabin_type_B</th>\n",
|
|||
|
" <th>Cabin_type_C</th>\n",
|
|||
|
" <th>Cabin_type_D</th>\n",
|
|||
|
" <th>Cabin_type_E</th>\n",
|
|||
|
" <th>Cabin_type_F</th>\n",
|
|||
|
" <th>Cabin_type_G</th>\n",
|
|||
|
" <th>Cabin_type_T</th>\n",
|
|||
|
" <th>Cabin_type_u</th>\n",
|
|||
|
" <th>Is_married</th>\n",
|
|||
|
" <th>Pclass</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>SibSp</th>\n",
|
|||
|
" <th>Sex_male</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>PassengerId</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>145</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>-0.379423</td>\n",
|
|||
|
" <td>-0.869506</td>\n",
|
|||
|
" <td>-0.473465</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>206</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0.821241</td>\n",
|
|||
|
" <td>-2.102186</td>\n",
|
|||
|
" <td>-0.473465</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>349</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0.821241</td>\n",
|
|||
|
" <td>-2.025143</td>\n",
|
|||
|
" <td>0.437635</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>329</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.821241</td>\n",
|
|||
|
" <td>0.132047</td>\n",
|
|||
|
" <td>0.437635</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>289</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>-0.379423</td>\n",
|
|||
|
" <td>0.979514</td>\n",
|
|||
|
" <td>-0.473465</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>756</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>-0.379423</td>\n",
|
|||
|
" <td>-2.204652</td>\n",
|
|||
|
" <td>0.437635</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>816</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>-1.580088</td>\n",
|
|||
|
" <td>-0.099081</td>\n",
|
|||
|
" <td>-0.473465</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>890</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>-1.580088</td>\n",
|
|||
|
" <td>-0.253166</td>\n",
|
|||
|
" <td>-0.473465</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>738</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>-1.580088</td>\n",
|
|||
|
" <td>0.440217</td>\n",
|
|||
|
" <td>-0.473465</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>61</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0.821241</td>\n",
|
|||
|
" <td>-0.561336</td>\n",
|
|||
|
" <td>-0.473465</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>712 rows × 13 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Cabin_type_B Cabin_type_C Cabin_type_D Cabin_type_E \\\n",
|
|||
|
"PassengerId \n",
|
|||
|
"145 0.0 0.0 0.0 0.0 \n",
|
|||
|
"206 0.0 0.0 0.0 0.0 \n",
|
|||
|
"349 0.0 0.0 0.0 0.0 \n",
|
|||
|
"329 0.0 0.0 0.0 0.0 \n",
|
|||
|
"289 0.0 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"756 0.0 0.0 0.0 0.0 \n",
|
|||
|
"816 1.0 0.0 0.0 0.0 \n",
|
|||
|
"890 0.0 1.0 0.0 0.0 \n",
|
|||
|
"738 1.0 0.0 0.0 0.0 \n",
|
|||
|
"61 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Cabin_type_F Cabin_type_G Cabin_type_T Cabin_type_u \\\n",
|
|||
|
"PassengerId \n",
|
|||
|
"145 0.0 0.0 0.0 1.0 \n",
|
|||
|
"206 0.0 1.0 0.0 0.0 \n",
|
|||
|
"349 0.0 0.0 0.0 1.0 \n",
|
|||
|
"329 0.0 0.0 0.0 1.0 \n",
|
|||
|
"289 0.0 0.0 0.0 1.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"756 0.0 0.0 0.0 1.0 \n",
|
|||
|
"816 0.0 0.0 0.0 0.0 \n",
|
|||
|
"890 0.0 0.0 0.0 0.0 \n",
|
|||
|
"738 0.0 0.0 0.0 0.0 \n",
|
|||
|
"61 0.0 0.0 0.0 1.0 \n",
|
|||
|
"\n",
|
|||
|
" Is_married Pclass Age SibSp Sex_male \n",
|
|||
|
"PassengerId \n",
|
|||
|
"145 0 -0.379423 -0.869506 -0.473465 1.0 \n",
|
|||
|
"206 0 0.821241 -2.102186 -0.473465 0.0 \n",
|
|||
|
"349 0 0.821241 -2.025143 0.437635 1.0 \n",
|
|||
|
"329 1 0.821241 0.132047 0.437635 0.0 \n",
|
|||
|
"289 0 -0.379423 0.979514 -0.473465 1.0 \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"756 0 -0.379423 -2.204652 0.437635 1.0 \n",
|
|||
|
"816 0 -1.580088 -0.099081 -0.473465 1.0 \n",
|
|||
|
"890 0 -1.580088 -0.253166 -0.473465 1.0 \n",
|
|||
|
"738 0 -1.580088 0.440217 -0.473465 1.0 \n",
|
|||
|
"61 0 0.821241 -0.561336 -0.473465 1.0 \n",
|
|||
|
"\n",
|
|||
|
"[712 rows x 13 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"# Fit the preprocessing pipeline on the training split and inspect the features it produces.\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
"    data=preprocessing_result,\n",
"    columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование набора моделей для классификации\n",
|
|||
|
"\n",
|
|||
|
"logistic -- логистическая регрессия\n",
|
|||
|
"\n",
|
|||
"ridge -- логистическая регрессия с L2-регуляризацией (гребневым штрафом) и сбалансированными весами классов\n",
|
|||
|
"\n",
|
|||
|
"decision_tree -- дерево решений\n",
|
|||
|
"\n",
|
|||
"knn -- метод k ближайших соседей\n",
|
|||
|
"\n",
|
|||
"naive_bayes -- наивный байесовский классификатор\n",
|
|||
|
"\n",
|
|||
|
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"random_forest -- метод случайного леса (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"mlp -- многослойный персептрон (нейронная сеть)\n",
|
|||
|
"\n",
|
|||
|
"Документация: https://scikit-learn.org/1.5/supervised_learning.html"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"# Candidate classifiers keyed by short name. Each entry is later overwritten with the\n",
"# result of run_classification in the training loop. Estimators with internal\n",
"# randomness are seeded with the notebook-wide `random_state` for reproducibility.\n",
"class_models = {\n",
"    \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
"    # \"ridge\" is logistic regression with an L2 (ridge) penalty and balanced class weights.\n",
"    # RidgeClassifierCV would be the literal ridge classifier, but it has no predict_proba.\n",
"    \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
"    \"decision_tree\": {\n",
"        \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
"    },\n",
"    \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
"    \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
"    \"gradient_boosting\": {\n",
"        # random_state pins the per-split feature permutation so reruns give identical trees.\n",
"        \"model\": ensemble.GradientBoostingClassifier(\n",
"            n_estimators=210, random_state=random_state\n",
"        )\n",
"    },\n",
"    \"random_forest\": {\n",
"        \"model\": ensemble.RandomForestClassifier(\n",
"            max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
"        )\n",
"    },\n",
"    \"mlp\": {\n",
"        \"model\": neural_network.MLPClassifier(\n",
"            hidden_layer_sizes=(7,),\n",
"            max_iter=500,\n",
"            early_stopping=True,\n",
"            random_state=random_state,\n",
"        )\n",
"    },\n",
"}"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"Model: knn\n",
|
|||
|
"Model: naive_bayes\n",
|
|||
|
"Model: gradient_boosting\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"from sklearn.base import clone\n",
"\n",
"from src.utils import run_classification\n",
"\n",
"# Fit every classifier behind the preprocessing pipeline and evaluate it on train/test.\n",
"# clone() gives each model its own unfitted copy of `pipeline_end`; without it all stored\n",
"# pipelines would share (and successively refit) one preprocessing instance.\n",
"# NOTE: each entry is overwritten with run_classification's result, so re-running this cell\n",
"# presumably requires re-running the model-definition cell first - confirm.\n",
"for model_name, model_spec in list(class_models.items()):\n",
"    print(f\"Model: {model_name}\")\n",
"    model = model_spec[\"model\"]\n",
"\n",
"    pipeline = Pipeline([(\"pipeline\", clone(pipeline_end)), (\"model\", model)]).fit(\n",
"        X_train, y_train.values.ravel()\n",
"    )\n",
"\n",
"    class_models[model_name] = run_classification(\n",
"        pipeline, X_train, X_test, y_train, y_test\n",
"    )"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Сводная таблица оценок качества для использованных моделей классификации\n",
|
|||
|
"\n",
|
|||
|
"Документация: https://scikit-learn.org/1.5/modules/model_evaluation.html"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Матрица неточностей"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1QAAAQ9CAYAAABePQxBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/GU6VOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVhUZfsH8O9hR3YQWQQRFfd9yXBNRVFTIUnT6OeSS6+a62supQi4UylpqW1ub5qZJZmmZu5blmsuuIOisqgICMo2c35/0IxNMAMDM8yc4fu5rnMVz3PmzH1GnZv7nOc8jyCKoggiIiIiIiLSmpmhAyAiIiIiIpIqFlRERERERETlxIKKiIiIiIionFhQERERERERlRMLKiIiIiIionJiQUVERERERFROLKiIiIiIiIjKiQUVERERERFRObGgIiIiIiIiKicWVGRU1q9fD0EQkJiYqJfjJyYmQhAErF+/XifHO3ToEARBwKFDh3RyPCIiIlMRGRkJQRDKtK8gCIiMjNRvQER6woKKqAxWrVqlsyKMiIiIiEyHhaEDIKpMfn5+eP78OSwtLbV63apVq1C9enWMGDFCpb1Lly54/vw5rKysdBglERGR9M2ZMwezZs0ydBhEeseCiqoUQRBgY2Ojs+OZmZnp9HhERESmICcnB3Z2drCw4K+aZPo45I+M3qpVq9CkSRNYW1vD29sbEyZMQEZGRrH9PvvsM9SpUwe2trZ46aWXcPToUbzyyit45ZVXlPuU9AxVSkoKRo4cCR8fH1hbW8PLywshISHK57hq166Ny5cv4/DhwxAEAYIgKI+p7hmqU6dOoW/fvnBxcYGdnR2aN2+OTz75RLcfDBERkRFQPCt15coVvPnmm3BxcUGnTp1KfIYqLy8PU6dOhbu7OxwcHDBgwADcu3evxOMeOnQIbdu2hY2NDerWrYvPP/9c7XNZ33zzDdq0aQNbW1u4urpiyJAhSEpK0sv5Ev0bLxuQUYuMjERUVBSCgoIwbtw4XLt2DatXr8aff/6J48ePK4furV69Gu+++y46d+6MqVOnIjExEaGhoXBxcYGPj4/G9wgLC8Ply5cxceJE1K5dG2lpadi3bx/u3r2L2rVrIzY2FhMnToS9vT0++OADAICHh4fa4+3btw/9+vWDl5cXJk+eDE9PT8THx2Pnzp2YPHmy7j4cIiIiIzJo0CAEBARg0aJFEEURaWlpxfYZPXo0vvnmG7z55pvo0KEDDhw4gFdffbXYfufOnUPv3r3h5eWFqKgoyGQyREdHw93dvdi+CxcuxNy5czF48GCMHj0aDx8+xMqVK9GlSxecO3cOzs7O+jhdohdEIiOybt06EYCYkJAgpqWliVZWVmKvXr1EmUym3OfTTz8VAYhr164VRVEU8/LyRDc3N7Fdu3ZiQUGBcr/169eLAMSuXbsq2xISEkQA4rp160RRFMUnT56IAMQPP/xQY1xNmjRROY7CwYMHRQDiwYMHRVEUxcLCQtHf31/08/MTnzx5orKvXC4v+wdBREQkEfPmzRMBiEOHDi2xXeH8+fMiAHH8+PEq+7355psiAHHevHnKtv79+4vVqlUT79+/r2y7ceOGaGFhoXLMxMRE0dzcXFy4cKHKMS9evChaWFgUayfSBw75I6P122+/IT8/H1OmTIGZ2Yu/qmPGjIGjoyN27doFADh9+jQeP36MMWPGqIzVDg8Ph4uLi8b3sLW1hZWVFQ4dOoQnT55UOOZz584hISEBU6ZMKXZFrKxTxxIREUnRf/7zH439v/zyCwBg0qRJKu1TpkxR+Vkmk+G3335DaGgovL29le316tVDnz59VPb98ccfIZfLMXjwYDx69Ei5eXp6IiAgAAcPHqzAGRGVDYf8kdG6c+cOAKBBgwYq7VZWVqhTp46yX/HfevXqqexnYWGB2rVra3wPa2trLF26FP/973/h4eGBl19+Gf369cOwYcPg6empdcy3bt0CADRt2lTr1x
IREUmZv7+/xv47d+7AzMwMdevWVWn/d55PS0vD8+fPi+V1oHiuv3HjBkRRREBAQInvqe2svkTlwYKKqrwpU6agf//+iIuLw969ezF37lwsXrwYBw4cQKtWrQwdHhERkSTY2tpW+nvK5XIIgoDdu3fD3Ny8WL+9vX2lx0RVD4f8kdHy8/MDAFy7dk2lPT8/HwkJCcp+xX9v3rypsl9hYaFypr7S1K1bF//973/x66+/4tKlS8jPz8fHH3+s7C/rcD3FVbdLly6VaX8iIqKqws/PD3K5XDmaQ+Hfeb5GjRqwsbEplteB4rm+bt26EEUR/v7+CAoKKra9/PLLuj8Ron9hQUVGKygoCFZWVlixYgVEUVS2f/3118jMzFTOCtS2bVu4ubnhyy+/RGFhoXK/TZs2lfpc1LNnz5Cbm6vSVrduXTg4OCAvL0/ZZmdnV+JU7f/WunVr+Pv7IzY2ttj+/zwHIiKiqkbx/NOKFStU2mNjY1V+Njc3R1BQEOLi4vDgwQNl+82bN7F7926VfQcOHAhzc3NERUUVy7OiKOLx48c6PAOiknHIHxktd3d3zJ49G1FRUejduzcGDBiAa9euYdWqVWjXrh3eeustAEXPVEVGRmLixIno3r07Bg8ejMTERKxfvx5169bVeHfp+vXr6NGjBwYPHozGjRvDwsIC27dvR2pqKoYMGaLcr02bNli9ejUWLFiAevXqoUaNGujevXux45mZmWH16tXo378/WrZsiZEjR8LLywtXr17F5cuXsXfvXt1/UERERBLQsmVLDB06FKtWrUJmZiY6dOiA/fv3l3gnKjIyEr/++is6duyIcePGQSaT4dNPP0XTpk1x/vx55X5169bFggULMHv2bOWSKQ4ODkhISMD27dsxduxYTJ8+vRLPkqoiFlRk1CIjI+Hu7o5PP/0UU6dOhaurK8aOHYtFixapPGj67rvvQhRFfPzxx5g+fTpatGiBHTt2YNKkSbCxsVF7fF9fXwwdOhT79+/H//73P1hYWKBhw4bYunUrwsLClPtFRETgzp07iImJwdOnT9G1a9cSCyoACA4OxsGDBxEVFYWPP/4YcrkcdevWxZgxY3T3wRAREUnQ2rVr4e7ujk2bNiEuLg7du3fHrl274Ovrq7JfmzZtsHv3bkyfPh1z586Fr68voqOjER8fj6tXr6rsO2vWLNSvXx/Lly9HVFQUgKL83qtXLwwYMKDSzo2qLkHkOCQyUXK5HO7u7hg4cCC+/PJLQ4dDREREFRQaGorLly/jxo0bhg6FSInPUJFJyM3NLTZ2euPGjUhPT8crr7ximKCIiIio3J4/f67y840bN/DLL78wr5PR4R0qMgmHDh3C1KlTMWjQILi5ueHs2bP4+uuv0ahRI5w5cwZWVlaGDpGIiIi04OXlhREjRijXnly9ejXy8vJw7tw5tetOERkCn6Eik1C7dm34+vpixYoVSE9Ph6urK4YNG4YlS5awmCIiIpKg3r1749tvv0VKSgqsra0RGBiIRYsWsZgio8M7VEREREREROXEZ6iIiIiIiIjKiQUVERERERFROfEZKgORy+V48OABHBwcNC48S2SqRFHE06dP4e3tDTMz3V7byc3NRX5+vsZ9rKysNK5RRkRVE/MzVWWGzs2ANPMzCyoDefDgQbFF7IiqoqSkJPj4+OjseLm5ufD3s0dKmkzjfp6enkhISJDclzYR6RfzM5HhcjMgzfzMgspAHBwcAAB3ztaGoz1HXhrCa/WbGTqEKq0QBTiGX5T/FnQlPz8fKWky3DztC0eHkv9tZT2Vo17bJOTn50vqC5uI9I/52fBefzXE0CFUWYWyPBy+vcoguRmQbn5mQWUgimEEjvZmGv9ikf5YCJaGDqFq+3t+UX0NqbF3EGDvUPKx5eAwHiIqGfOz4VmYWxs6hCrPELkZkG5+ZkFFRCapQJShQM2qEAWivJKjISIiIk25uahfmvmZBRURmSQ5RMhR8pe2unYiIiLSH025WdEvRS
yoiMgkySFCxoKKiIjIaGjKzYp+KWJBRUQmqUCUo0DN97JUhxQQERFJmabcrOiXIhZURGSS5H9v6vqIiIiocmnKzSi
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x1000 with 16 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# One confusion matrix per model, laid out in two columns. The row count is rounded up\n",
"# so that an odd number of models still gets one subplot each.\n",
"n_rows = (len(class_models) + 1) // 2\n",
"_, ax = plt.subplots(n_rows, 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
"    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
"    # Labels follow the target's class order: 0 = died, 1 = survived.\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Died\", \"Survived\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"    disp.ax_.set_title(key)\n",
"\n",
"# Hide the spare subplot that remains when the number of models is odd.\n",
"for spare_ax in ax.ravel()[len(class_models):]:\n",
"    spare_ax.set_visible(False)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Точность, полнота, верность (аккуратность), F-мера"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_4d92f_row0_col0, #T_4d92f_row7_col1 {\n",
|
|||
|
" background-color: #a2da37;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row0_col1, #T_4d92f_row3_col2, #T_4d92f_row5_col3 {\n",
|
|||
|
" background-color: #77d153;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row0_col2 {\n",
|
|||
|
" background-color: #a5db36;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row0_col3 {\n",
|
|||
|
" background-color: #7cd250;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row0_col4, #T_4d92f_row0_col5, #T_4d92f_row0_col6, #T_4d92f_row0_col7 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row1_col0 {\n",
|
|||
|
" background-color: #a0da39;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row1_col1 {\n",
|
|||
|
" background-color: #7fd34e;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row1_col2 {\n",
|
|||
|
" background-color: #93d741;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row1_col3 {\n",
|
|||
|
" background-color: #70cf57;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row1_col4 {\n",
|
|||
|
" background-color: #d45270;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row1_col5, #T_4d92f_row1_col6 {\n",
|
|||
|
" background-color: #d6556d;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row1_col7 {\n",
|
|||
|
" background-color: #d7566c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row2_col0 {\n",
|
|||
|
" background-color: #3bbb75;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row2_col1 {\n",
|
|||
|
" background-color: #84d44b;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row2_col2, #T_4d92f_row4_col0 {\n",
|
|||
|
" background-color: #6ece58;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row2_col3 {\n",
|
|||
|
" background-color: #65cb5e;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row2_col4, #T_4d92f_row5_col4 {\n",
|
|||
|
" background-color: #a62098;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row2_col5, #T_4d92f_row3_col5, #T_4d92f_row4_col5, #T_4d92f_row4_col7 {\n",
|
|||
|
" background-color: #d35171;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row2_col6 {\n",
|
|||
|
" background-color: #c03a83;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row2_col7 {\n",
|
|||
|
" background-color: #d5536f;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row3_col0 {\n",
|
|||
|
" background-color: #81d34d;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row3_col1, #T_4d92f_row6_col2, #T_4d92f_row6_col3, #T_4d92f_row7_col0 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row3_col3 {\n",
|
|||
|
" background-color: #56c667;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row3_col4 {\n",
|
|||
|
" background-color: #c33d80;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row3_col6, #T_4d92f_row5_col7 {\n",
|
|||
|
" background-color: #cc4977;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row3_col7 {\n",
|
|||
|
" background-color: #d24f71;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row4_col1 {\n",
|
|||
|
" background-color: #9bd93c;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row4_col2 {\n",
|
|||
|
" background-color: #6ccd5a;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row4_col3 {\n",
|
|||
|
" background-color: #5cc863;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row4_col4 {\n",
|
|||
|
" background-color: #b83289;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row4_col6 {\n",
|
|||
|
" background-color: #c7427c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row5_col0 {\n",
|
|||
|
" background-color: #2db27d;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row5_col1 {\n",
|
|||
|
" background-color: #26ad81;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row5_col2 {\n",
|
|||
|
" background-color: #89d548;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row5_col5 {\n",
|
|||
|
" background-color: #ae2892;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row5_col6 {\n",
|
|||
|
" background-color: #c43e7f;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row6_col0, #T_4d92f_row6_col1, #T_4d92f_row7_col2, #T_4d92f_row7_col3 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row6_col4, #T_4d92f_row7_col5, #T_4d92f_row7_col6, #T_4d92f_row7_col7 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row6_col5 {\n",
|
|||
|
" background-color: #6700a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row6_col6 {\n",
|
|||
|
" background-color: #b32c8e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row6_col7 {\n",
|
|||
|
" background-color: #c5407e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_4d92f_row7_col4 {\n",
|
|||
|
" background-color: #5002a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_4d92f\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_4d92f_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_4d92f_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_4d92f_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_4d92f_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_4d92f_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_4d92f_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_4d92f_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_4d92f_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4d92f_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_4d92f_row0_col0\" class=\"data row0 col0\" >0.894340</td>\n",
|
|||
|
" <td id=\"T_4d92f_row0_col1\" class=\"data row0 col1\" >0.794118</td>\n",
|
|||
|
" <td id=\"T_4d92f_row0_col2\" class=\"data row0 col2\" >0.868132</td>\n",
|
|||
|
" <td id=\"T_4d92f_row0_col3\" class=\"data row0 col3\" >0.782609</td>\n",
|
|||
|
" <td id=\"T_4d92f_row0_col4\" class=\"data row0 col4\" >0.910112</td>\n",
|
|||
|
" <td id=\"T_4d92f_row0_col5\" class=\"data row0 col5\" >0.837989</td>\n",
|
|||
|
" <td id=\"T_4d92f_row0_col6\" class=\"data row0 col6\" >0.881041</td>\n",
|
|||
|
" <td id=\"T_4d92f_row0_col7\" class=\"data row0 col7\" >0.788321</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4d92f_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_4d92f_row1_col0\" class=\"data row1 col0\" >0.889764</td>\n",
|
|||
|
" <td id=\"T_4d92f_row1_col1\" class=\"data row1 col1\" >0.800000</td>\n",
|
|||
|
" <td id=\"T_4d92f_row1_col2\" class=\"data row1 col2\" >0.827839</td>\n",
|
|||
|
" <td id=\"T_4d92f_row1_col3\" class=\"data row1 col3\" >0.753623</td>\n",
|
|||
|
" <td id=\"T_4d92f_row1_col4\" class=\"data row1 col4\" >0.894663</td>\n",
|
|||
|
" <td id=\"T_4d92f_row1_col5\" class=\"data row1 col5\" >0.832402</td>\n",
|
|||
|
" <td id=\"T_4d92f_row1_col6\" class=\"data row1 col6\" >0.857685</td>\n",
|
|||
|
" <td id=\"T_4d92f_row1_col7\" class=\"data row1 col7\" >0.776119</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4d92f_level0_row2\" class=\"row_heading level0 row2\" >logistic</th>\n",
|
|||
|
" <td id=\"T_4d92f_row2_col0\" class=\"data row2 col0\" >0.751880</td>\n",
|
|||
|
" <td id=\"T_4d92f_row2_col1\" class=\"data row2 col1\" >0.806452</td>\n",
|
|||
|
" <td id=\"T_4d92f_row2_col2\" class=\"data row2 col2\" >0.732601</td>\n",
|
|||
|
" <td id=\"T_4d92f_row2_col3\" class=\"data row2 col3\" >0.724638</td>\n",
|
|||
|
" <td id=\"T_4d92f_row2_col4\" class=\"data row2 col4\" >0.804775</td>\n",
|
|||
|
" <td id=\"T_4d92f_row2_col5\" class=\"data row2 col5\" >0.826816</td>\n",
|
|||
|
" <td id=\"T_4d92f_row2_col6\" class=\"data row2 col6\" >0.742115</td>\n",
|
|||
|
" <td id=\"T_4d92f_row2_col7\" class=\"data row2 col7\" >0.763359</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4d92f_level0_row3\" class=\"row_heading level0 row3\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_4d92f_row3_col0\" class=\"data row3 col0\" >0.852459</td>\n",
|
|||
|
" <td id=\"T_4d92f_row3_col1\" class=\"data row3 col1\" >0.839286</td>\n",
|
|||
|
" <td id=\"T_4d92f_row3_col2\" class=\"data row3 col2\" >0.761905</td>\n",
|
|||
|
" <td id=\"T_4d92f_row3_col3\" class=\"data row3 col3\" >0.681159</td>\n",
|
|||
|
" <td id=\"T_4d92f_row3_col4\" class=\"data row3 col4\" >0.858146</td>\n",
|
|||
|
" <td id=\"T_4d92f_row3_col5\" class=\"data row3 col5\" >0.826816</td>\n",
|
|||
|
" <td id=\"T_4d92f_row3_col6\" class=\"data row3 col6\" >0.804642</td>\n",
|
|||
|
" <td id=\"T_4d92f_row3_col7\" class=\"data row3 col7\" >0.752000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4d92f_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
|
|||
|
" <td id=\"T_4d92f_row4_col0\" class=\"data row4 col0\" >0.829167</td>\n",
|
|||
|
" <td id=\"T_4d92f_row4_col1\" class=\"data row4 col1\" >0.827586</td>\n",
|
|||
|
" <td id=\"T_4d92f_row4_col2\" class=\"data row4 col2\" >0.728938</td>\n",
|
|||
|
" <td id=\"T_4d92f_row4_col3\" class=\"data row4 col3\" >0.695652</td>\n",
|
|||
|
" <td id=\"T_4d92f_row4_col4\" class=\"data row4 col4\" >0.838483</td>\n",
|
|||
|
" <td id=\"T_4d92f_row4_col5\" class=\"data row4 col5\" >0.826816</td>\n",
|
|||
|
" <td id=\"T_4d92f_row4_col6\" class=\"data row4 col6\" >0.775828</td>\n",
|
|||
|
" <td id=\"T_4d92f_row4_col7\" class=\"data row4 col7\" >0.755906</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4d92f_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
|
|||
|
" <td id=\"T_4d92f_row5_col0\" class=\"data row5 col0\" >0.720395</td>\n",
|
|||
|
" <td id=\"T_4d92f_row5_col1\" class=\"data row5 col1\" >0.688312</td>\n",
|
|||
|
" <td id=\"T_4d92f_row5_col2\" class=\"data row5 col2\" >0.802198</td>\n",
|
|||
|
" <td id=\"T_4d92f_row5_col3\" class=\"data row5 col3\" >0.768116</td>\n",
|
|||
|
" <td id=\"T_4d92f_row5_col4\" class=\"data row5 col4\" >0.804775</td>\n",
|
|||
|
" <td id=\"T_4d92f_row5_col5\" class=\"data row5 col5\" >0.776536</td>\n",
|
|||
|
" <td id=\"T_4d92f_row5_col6\" class=\"data row5 col6\" >0.759099</td>\n",
|
|||
|
" <td id=\"T_4d92f_row5_col7\" class=\"data row5 col7\" >0.726027</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4d92f_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_4d92f_row6_col0\" class=\"data row6 col0\" >0.554524</td>\n",
|
|||
|
" <td id=\"T_4d92f_row6_col1\" class=\"data row6 col1\" >0.575472</td>\n",
|
|||
|
" <td id=\"T_4d92f_row6_col2\" class=\"data row6 col2\" >0.875458</td>\n",
|
|||
|
" <td id=\"T_4d92f_row6_col3\" class=\"data row6 col3\" >0.884058</td>\n",
|
|||
|
" <td id=\"T_4d92f_row6_col4\" class=\"data row6 col4\" >0.682584</td>\n",
|
|||
|
" <td id=\"T_4d92f_row6_col5\" class=\"data row6 col5\" >0.703911</td>\n",
|
|||
|
" <td id=\"T_4d92f_row6_col6\" class=\"data row6 col6\" >0.678977</td>\n",
|
|||
|
" <td id=\"T_4d92f_row6_col7\" class=\"data row6 col7\" >0.697143</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_4d92f_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
|
|||
|
" <td id=\"T_4d92f_row7_col0\" class=\"data row7 col0\" >0.900000</td>\n",
|
|||
|
" <td id=\"T_4d92f_row7_col1\" class=\"data row7 col1\" >0.833333</td>\n",
|
|||
|
" <td id=\"T_4d92f_row7_col2\" class=\"data row7 col2\" >0.197802</td>\n",
|
|||
|
" <td id=\"T_4d92f_row7_col3\" class=\"data row7 col3\" >0.217391</td>\n",
|
|||
|
" <td id=\"T_4d92f_row7_col4\" class=\"data row7 col4\" >0.683989</td>\n",
|
|||
|
" <td id=\"T_4d92f_row7_col5\" class=\"data row7 col5\" >0.681564</td>\n",
|
|||
|
" <td id=\"T_4d92f_row7_col6\" class=\"data row7 col6\" >0.324324</td>\n",
|
|||
|
" <td id=\"T_4d92f_row7_col7\" class=\"data row7 col7\" >0.344828</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x2827fc5e2d0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(\n",
|
|||
|
" by=\"Accuracy_test\", ascending=False\n",
|
|||
|
").style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_71e02_row0_col0, #T_71e02_row0_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row0_col2, #T_71e02_row0_col3, #T_71e02_row0_col4 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row1_col0, #T_71e02_row4_col0, #T_71e02_row4_col1, #T_71e02_row5_col0 {\n",
|
|||
|
" background-color: #93d741;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row1_col1 {\n",
|
|||
|
" background-color: #98d83e;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row1_col2 {\n",
|
|||
|
" background-color: #d7566c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row1_col3 {\n",
|
|||
|
" background-color: #d45270;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row1_col4, #T_71e02_row4_col4, #T_71e02_row5_col3, #T_71e02_row5_col4 {\n",
|
|||
|
" background-color: #d24f71;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row2_col0 {\n",
|
|||
|
" background-color: #42be71;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row2_col1 {\n",
|
|||
|
" background-color: #7fd34e;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row2_col2, #T_71e02_row3_col2 {\n",
|
|||
|
" background-color: #d5536f;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row2_col3 {\n",
|
|||
|
" background-color: #be3885;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row2_col4 {\n",
|
|||
|
" background-color: #b6308b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row3_col0 {\n",
|
|||
|
" background-color: #9dd93b;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row3_col1 {\n",
|
|||
|
" background-color: #a0da39;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row3_col3, #T_71e02_row3_col4 {\n",
|
|||
|
" background-color: #d6556d;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row4_col2 {\n",
|
|||
|
" background-color: #cc4977;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row4_col3 {\n",
|
|||
|
" background-color: #d35171;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row5_col1 {\n",
|
|||
|
" background-color: #90d743;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row5_col2 {\n",
|
|||
|
" background-color: #a82296;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row6_col0 {\n",
|
|||
|
" background-color: #21908d;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row6_col1 {\n",
|
|||
|
" background-color: #6ece58;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row6_col2 {\n",
|
|||
|
" background-color: #a11b9b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row6_col3 {\n",
|
|||
|
" background-color: #9e199d;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row6_col4 {\n",
|
|||
|
" background-color: #9c179e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row7_col0, #T_71e02_row7_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_71e02_row7_col2, #T_71e02_row7_col3, #T_71e02_row7_col4 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_71e02\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_71e02_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_71e02_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_71e02_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_71e02_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_71e02_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_71e02_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_71e02_row0_col0\" class=\"data row0 col0\" >0.837989</td>\n",
|
|||
|
" <td id=\"T_71e02_row0_col1\" class=\"data row0 col1\" >0.788321</td>\n",
|
|||
|
" <td id=\"T_71e02_row0_col2\" class=\"data row0 col2\" >0.858893</td>\n",
|
|||
|
" <td id=\"T_71e02_row0_col3\" class=\"data row0 col3\" >0.657111</td>\n",
|
|||
|
" <td id=\"T_71e02_row0_col4\" class=\"data row0 col4\" >0.657157</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_71e02_level0_row1\" class=\"row_heading level0 row1\" >logistic</th>\n",
|
|||
|
" <td id=\"T_71e02_row1_col0\" class=\"data row1 col0\" >0.826816</td>\n",
|
|||
|
" <td id=\"T_71e02_row1_col1\" class=\"data row1 col1\" >0.763359</td>\n",
|
|||
|
" <td id=\"T_71e02_row1_col2\" class=\"data row1 col2\" >0.854084</td>\n",
|
|||
|
" <td id=\"T_71e02_row1_col3\" class=\"data row1 col3\" >0.627409</td>\n",
|
|||
|
" <td id=\"T_71e02_row1_col4\" class=\"data row1 col4\" >0.629641</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_71e02_level0_row2\" class=\"row_heading level0 row2\" >ridge</th>\n",
|
|||
|
" <td id=\"T_71e02_row2_col0\" class=\"data row2 col0\" >0.776536</td>\n",
|
|||
|
" <td id=\"T_71e02_row2_col1\" class=\"data row2 col1\" >0.726027</td>\n",
|
|||
|
" <td id=\"T_71e02_row2_col2\" class=\"data row2 col2\" >0.851054</td>\n",
|
|||
|
" <td id=\"T_71e02_row2_col3\" class=\"data row2 col3\" >0.538303</td>\n",
|
|||
|
" <td id=\"T_71e02_row2_col4\" class=\"data row2 col4\" >0.540613</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_71e02_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_71e02_row3_col0\" class=\"data row3 col0\" >0.832402</td>\n",
|
|||
|
" <td id=\"T_71e02_row3_col1\" class=\"data row3 col1\" >0.776119</td>\n",
|
|||
|
" <td id=\"T_71e02_row3_col2\" class=\"data row3 col2\" >0.850922</td>\n",
|
|||
|
" <td id=\"T_71e02_row3_col3\" class=\"data row3 col3\" >0.642381</td>\n",
|
|||
|
" <td id=\"T_71e02_row3_col4\" class=\"data row3 col4\" >0.643113</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_71e02_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
|
|||
|
" <td id=\"T_71e02_row4_col0\" class=\"data row4 col0\" >0.826816</td>\n",
|
|||
|
" <td id=\"T_71e02_row4_col1\" class=\"data row4 col1\" >0.755906</td>\n",
|
|||
|
" <td id=\"T_71e02_row4_col2\" class=\"data row4 col2\" >0.838735</td>\n",
|
|||
|
" <td id=\"T_71e02_row4_col3\" class=\"data row4 col3\" >0.623260</td>\n",
|
|||
|
" <td id=\"T_71e02_row4_col4\" class=\"data row4 col4\" >0.628905</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_71e02_level0_row5\" class=\"row_heading level0 row5\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_71e02_row5_col0\" class=\"data row5 col0\" >0.826816</td>\n",
|
|||
|
" <td id=\"T_71e02_row5_col1\" class=\"data row5 col1\" >0.752000</td>\n",
|
|||
|
" <td id=\"T_71e02_row5_col2\" class=\"data row5 col2\" >0.794137</td>\n",
|
|||
|
" <td id=\"T_71e02_row5_col3\" class=\"data row5 col3\" >0.621151</td>\n",
|
|||
|
" <td id=\"T_71e02_row5_col4\" class=\"data row5 col4\" >0.629142</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_71e02_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_71e02_row6_col0\" class=\"data row6 col0\" >0.703911</td>\n",
|
|||
|
" <td id=\"T_71e02_row6_col1\" class=\"data row6 col1\" >0.697143</td>\n",
|
|||
|
" <td id=\"T_71e02_row6_col2\" class=\"data row6 col2\" >0.785903</td>\n",
|
|||
|
" <td id=\"T_71e02_row6_col3\" class=\"data row6 col3\" >0.431814</td>\n",
|
|||
|
" <td id=\"T_71e02_row6_col4\" class=\"data row6 col4\" >0.470403</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_71e02_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
|
|||
|
" <td id=\"T_71e02_row7_col0\" class=\"data row7 col0\" >0.681564</td>\n",
|
|||
|
" <td id=\"T_71e02_row7_col1\" class=\"data row7 col1\" >0.344828</td>\n",
|
|||
|
" <td id=\"T_71e02_row7_col2\" class=\"data row7 col2\" >0.712714</td>\n",
|
|||
|
" <td id=\"T_71e02_row7_col3\" class=\"data row7 col3\" >0.220490</td>\n",
|
|||
|
" <td id=\"T_71e02_row7_col4\" class=\"data row7 col4\" >0.307678</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x2827fc885f0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'random_forest'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
|
|||
|
"\n",
|
|||
|
"display(best_model)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Вывод данных с ошибкой предсказания для оценки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'Error items count: 29'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Survived</th>\n",
|
|||
|
" <th>Predicted</th>\n",
|
|||
|
" <th>Pclass</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>SibSp</th>\n",
|
|||
|
" <th>Parch</th>\n",
|
|||
|
" <th>Ticket</th>\n",
|
|||
|
" <th>Fare</th>\n",
|
|||
|
" <th>Cabin</th>\n",
|
|||
|
" <th>Embarked</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>PassengerId</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>26</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Asplund, Mrs. Carl Oscar (Selma Augusta Emilia...</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>38.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>5</td>\n",
|
|||
|
" <td>347077</td>\n",
|
|||
|
" <td>31.3875</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>72</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Goodwin, Miss. Lillian Amy</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>16.0</td>\n",
|
|||
|
" <td>5</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>CA 2144</td>\n",
|
|||
|
" <td>46.9000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>103</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>White, Mr. Richard Frasar</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>21.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>35281</td>\n",
|
|||
|
" <td>77.2875</td>\n",
|
|||
|
" <td>D26</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>108</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Moss, Mr. Albert Johan</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>312991</td>\n",
|
|||
|
" <td>7.7750</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>128</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Madsen, Mr. Fridtjof Arne</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>24.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>C 17369</td>\n",
|
|||
|
" <td>7.1417</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>193</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Andersen-Jensen, Miss. Carla Christine Nielsine</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>19.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>350046</td>\n",
|
|||
|
" <td>7.8542</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>241</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Zabour, Miss. Thamine</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2665</td>\n",
|
|||
|
" <td>14.4542</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>272</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Tornquist, Mr. William Henry</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>25.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>LINE</td>\n",
|
|||
|
" <td>0.0000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>293</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Levy, Mr. Rene Jacques</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>36.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>SC/Paris 2163</td>\n",
|
|||
|
" <td>12.8750</td>\n",
|
|||
|
" <td>D</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>352</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Williams-Lambert, Mr. Fletcher Fellows</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>113510</td>\n",
|
|||
|
" <td>35.0000</td>\n",
|
|||
|
" <td>C128</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>358</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Funk, Miss. Annie Clemmer</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>38.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>237671</td>\n",
|
|||
|
" <td>13.0000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>378</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Widener, Mr. Harry Elkins</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>27.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>113503</td>\n",
|
|||
|
" <td>211.5000</td>\n",
|
|||
|
" <td>C82</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>445</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Johannesen-Bratthammer, Mr. Bernt</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>65306</td>\n",
|
|||
|
" <td>8.1125</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>450</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Peuchen, Major. Arthur Godfrey</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>52.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>113786</td>\n",
|
|||
|
" <td>30.5000</td>\n",
|
|||
|
" <td>C104</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>508</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Bradley, Mr. George (\"George Arthur Brayton\")</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>111427</td>\n",
|
|||
|
" <td>26.5500</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>511</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Daly, Mr. Eugene Patrick</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>29.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>382651</td>\n",
|
|||
|
" <td>7.7500</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>Q</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>570</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Jonsson, Mr. Carl</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>32.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>350417</td>\n",
|
|||
|
" <td>7.8542</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>579</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Caram, Mrs. Joseph (Maria Elias)</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2689</td>\n",
|
|||
|
" <td>14.4583</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>584</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Ross, Mr. John Hugo</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>36.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>13049</td>\n",
|
|||
|
" <td>40.1250</td>\n",
|
|||
|
" <td>A10</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>588</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Frolicher-Stehli, Mr. Maxmillian</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>60.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>13567</td>\n",
|
|||
|
" <td>79.2000</td>\n",
|
|||
|
" <td>B41</td>\n",
|
|||
|
" <td>C</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>618</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Lobb, Mrs. William Arthur (Cordelia K Stanlick)</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>26.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>A/5. 3336</td>\n",
|
|||
|
" <td>16.1000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>658</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Bourke, Mrs. John (Catherine)</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>32.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>364849</td>\n",
|
|||
|
" <td>15.5000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>Q</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>661</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Frauenthal, Dr. Henry William</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>50.0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>PC 17611</td>\n",
|
|||
|
" <td>133.6500</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>674</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Wilhelms, Mr. Charles</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>31.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>244270</td>\n",
|
|||
|
" <td>13.0000</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>745</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Stranden, Mr. Juho</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>31.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>STON/O 2. 3101288</td>\n",
|
|||
|
" <td>7.9250</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>773</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>Mack, Mrs. (Mary)</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>57.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>S.O./P.P. 3</td>\n",
|
|||
|
" <td>10.5000</td>\n",
|
|||
|
" <td>E77</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>807</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Andrews, Mr. Thomas Jr</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>39.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>112050</td>\n",
|
|||
|
" <td>0.0000</td>\n",
|
|||
|
" <td>A36</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>814</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>Andersson, Miss. Ebba Iris Alfrida</td>\n",
|
|||
|
" <td>female</td>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>347082</td>\n",
|
|||
|
" <td>31.2750</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>829</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>McCormack, Mr. Thomas Joseph</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>367228</td>\n",
|
|||
|
" <td>7.7500</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>Q</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Survived Predicted Pclass \\\n",
|
|||
|
"PassengerId \n",
|
|||
|
"26 1 0 3 \n",
|
|||
|
"72 0 1 3 \n",
|
|||
|
"103 0 1 1 \n",
|
|||
|
"108 1 0 3 \n",
|
|||
|
"128 1 0 3 \n",
|
|||
|
"193 1 0 3 \n",
|
|||
|
"241 0 1 3 \n",
|
|||
|
"272 1 0 3 \n",
|
|||
|
"293 0 1 2 \n",
|
|||
|
"352 0 1 1 \n",
|
|||
|
"358 0 1 2 \n",
|
|||
|
"378 0 1 1 \n",
|
|||
|
"445 1 0 3 \n",
|
|||
|
"450 1 0 1 \n",
|
|||
|
"508 1 0 1 \n",
|
|||
|
"511 1 0 3 \n",
|
|||
|
"570 1 0 3 \n",
|
|||
|
"579 0 1 3 \n",
|
|||
|
"584 0 1 1 \n",
|
|||
|
"588 1 0 1 \n",
|
|||
|
"618 0 1 3 \n",
|
|||
|
"658 0 1 3 \n",
|
|||
|
"661 1 0 1 \n",
|
|||
|
"674 1 0 2 \n",
|
|||
|
"745 1 0 3 \n",
|
|||
|
"773 0 1 2 \n",
|
|||
|
"807 0 1 1 \n",
|
|||
|
"814 0 1 3 \n",
|
|||
|
"829 1 0 3 \n",
|
|||
|
"\n",
|
|||
|
" Name Sex Age \\\n",
|
|||
|
"PassengerId \n",
|
|||
|
"26 Asplund, Mrs. Carl Oscar (Selma Augusta Emilia... female 38.0 \n",
|
|||
|
"72 Goodwin, Miss. Lillian Amy female 16.0 \n",
|
|||
|
"103 White, Mr. Richard Frasar male 21.0 \n",
|
|||
|
"108 Moss, Mr. Albert Johan male NaN \n",
|
|||
|
"128 Madsen, Mr. Fridtjof Arne male 24.0 \n",
|
|||
|
"193 Andersen-Jensen, Miss. Carla Christine Nielsine female 19.0 \n",
|
|||
|
"241 Zabour, Miss. Thamine female NaN \n",
|
|||
|
"272 Tornquist, Mr. William Henry male 25.0 \n",
|
|||
|
"293 Levy, Mr. Rene Jacques male 36.0 \n",
|
|||
|
"352 Williams-Lambert, Mr. Fletcher Fellows male NaN \n",
|
|||
|
"358 Funk, Miss. Annie Clemmer female 38.0 \n",
|
|||
|
"378 Widener, Mr. Harry Elkins male 27.0 \n",
|
|||
|
"445 Johannesen-Bratthammer, Mr. Bernt male NaN \n",
|
|||
|
"450 Peuchen, Major. Arthur Godfrey male 52.0 \n",
|
|||
|
"508 Bradley, Mr. George (\"George Arthur Brayton\") male NaN \n",
|
|||
|
"511 Daly, Mr. Eugene Patrick male 29.0 \n",
|
|||
|
"570 Jonsson, Mr. Carl male 32.0 \n",
|
|||
|
"579 Caram, Mrs. Joseph (Maria Elias) female NaN \n",
|
|||
|
"584 Ross, Mr. John Hugo male 36.0 \n",
|
|||
|
"588 Frolicher-Stehli, Mr. Maxmillian male 60.0 \n",
|
|||
|
"618 Lobb, Mrs. William Arthur (Cordelia K Stanlick) female 26.0 \n",
|
|||
|
"658 Bourke, Mrs. John (Catherine) female 32.0 \n",
|
|||
|
"661 Frauenthal, Dr. Henry William male 50.0 \n",
|
|||
|
"674 Wilhelms, Mr. Charles male 31.0 \n",
|
|||
|
"745 Stranden, Mr. Juho male 31.0 \n",
|
|||
|
"773 Mack, Mrs. (Mary) female 57.0 \n",
|
|||
|
"807 Andrews, Mr. Thomas Jr male 39.0 \n",
|
|||
|
"814 Andersson, Miss. Ebba Iris Alfrida female 6.0 \n",
|
|||
|
"829 McCormack, Mr. Thomas Joseph male NaN \n",
|
|||
|
"\n",
|
|||
|
" SibSp Parch Ticket Fare Cabin Embarked \n",
|
|||
|
"PassengerId \n",
|
|||
|
"26 1 5 347077 31.3875 NaN S \n",
|
|||
|
"72 5 2 CA 2144 46.9000 NaN S \n",
|
|||
|
"103 0 1 35281 77.2875 D26 S \n",
|
|||
|
"108 0 0 312991 7.7750 NaN S \n",
|
|||
|
"128 0 0 C 17369 7.1417 NaN S \n",
|
|||
|
"193 1 0 350046 7.8542 NaN S \n",
|
|||
|
"241 1 0 2665 14.4542 NaN C \n",
|
|||
|
"272 0 0 LINE 0.0000 NaN S \n",
|
|||
|
"293 0 0 SC/Paris 2163 12.8750 D C \n",
|
|||
|
"352 0 0 113510 35.0000 C128 S \n",
|
|||
|
"358 0 0 237671 13.0000 NaN S \n",
|
|||
|
"378 0 2 113503 211.5000 C82 C \n",
|
|||
|
"445 0 0 65306 8.1125 NaN S \n",
|
|||
|
"450 0 0 113786 30.5000 C104 S \n",
|
|||
|
"508 0 0 111427 26.5500 NaN S \n",
|
|||
|
"511 0 0 382651 7.7500 NaN Q \n",
|
|||
|
"570 0 0 350417 7.8542 NaN S \n",
|
|||
|
"579 1 0 2689 14.4583 NaN C \n",
|
|||
|
"584 0 0 13049 40.1250 A10 C \n",
|
|||
|
"588 1 1 13567 79.2000 B41 C \n",
|
|||
|
"618 1 0 A/5. 3336 16.1000 NaN S \n",
|
|||
|
"658 1 1 364849 15.5000 NaN Q \n",
|
|||
|
"661 2 0 PC 17611 133.6500 NaN S \n",
|
|||
|
"674 0 0 244270 13.0000 NaN S \n",
|
|||
|
"745 0 0 STON/O 2. 3101288 7.9250 NaN S \n",
|
|||
|
"773 0 0 S.O./P.P. 3 10.5000 E77 S \n",
|
|||
|
"807 0 0 112050 0.0000 A36 S \n",
|
|||
|
"814 4 2 347082 31.2750 NaN S \n",
|
|||
|
"829 0 0 367228 7.7500 NaN Q "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"preprocessing_result = pipeline_end.transform(X_test)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"y_pred = class_models[best_model][\"preds\"]\n",
|
|||
|
"\n",
|
|||
|
"error_index = y_test[y_test[\"Survived\"] != y_pred].index.tolist()\n",
|
|||
|
"display(f\"Error items count: {len(error_index)}\")\n",
|
|||
|
"\n",
|
|||
|
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
|
|||
|
"error_df = X_test.loc[error_index].copy()\n",
|
|||
|
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
|
|||
|
"error_df.sort_index()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Пример использования обученной модели (конвейера) для предсказания"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Survived</th>\n",
|
|||
|
" <th>Pclass</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>SibSp</th>\n",
|
|||
|
" <th>Parch</th>\n",
|
|||
|
" <th>Ticket</th>\n",
|
|||
|
" <th>Fare</th>\n",
|
|||
|
" <th>Cabin</th>\n",
|
|||
|
" <th>Embarked</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>450</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Peuchen, Major. Arthur Godfrey</td>\n",
|
|||
|
" <td>male</td>\n",
|
|||
|
" <td>52.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>113786</td>\n",
|
|||
|
" <td>30.5</td>\n",
|
|||
|
" <td>C104</td>\n",
|
|||
|
" <td>S</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Survived Pclass Name Sex Age SibSp Parch \\\n",
|
|||
|
"450 1 1 Peuchen, Major. Arthur Godfrey male 52.0 0 0 \n",
|
|||
|
"\n",
|
|||
|
" Ticket Fare Cabin Embarked \n",
|
|||
|
"450 113786 30.5 C104 S "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Cabin_type_B</th>\n",
|
|||
|
" <th>Cabin_type_C</th>\n",
|
|||
|
" <th>Cabin_type_D</th>\n",
|
|||
|
" <th>Cabin_type_E</th>\n",
|
|||
|
" <th>Cabin_type_F</th>\n",
|
|||
|
" <th>Cabin_type_G</th>\n",
|
|||
|
" <th>Cabin_type_T</th>\n",
|
|||
|
" <th>Cabin_type_u</th>\n",
|
|||
|
" <th>Is_married</th>\n",
|
|||
|
" <th>Pclass</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>SibSp</th>\n",
|
|||
|
" <th>Sex_male</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>450</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-1.580088</td>\n",
|
|||
|
" <td>1.749939</td>\n",
|
|||
|
" <td>-0.473465</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Cabin_type_B Cabin_type_C Cabin_type_D Cabin_type_E Cabin_type_F \\\n",
|
|||
|
"450 0.0 1.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Cabin_type_G Cabin_type_T Cabin_type_u Is_married Pclass Age \\\n",
|
|||
|
"450 0.0 0.0 0.0 0.0 -1.580088 1.749939 \n",
|
|||
|
"\n",
|
|||
|
" SibSp Sex_male \n",
|
|||
|
"450 -0.473465 1.0 "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'predicted: 0 (proba: [0.91145747 0.08854253])'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'real: 1'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"model = class_models[best_model][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"example_id = 450\n",
|
|||
|
"# Double brackets return a one-row DataFrame that keeps every column's dtype;\n",
|
|||
|
"# wrapping the row Series in pd.DataFrame(...).T would make all columns object.\n",
|
|||
|
"test = X_test.loc[[example_id]]\n",
|
|||
|
"test_preprocessed = preprocessed_df.loc[[example_id]]\n",
|
|||
|
"display(test)\n",
|
|||
|
"display(test_preprocessed)\n",
|
|||
|
"result_proba = model.predict_proba(test)[0]\n",
|
|||
|
"result = model.predict(test)[0]\n",
|
|||
|
"real = int(y_test.loc[example_id, \"Survived\"])\n",
|
|||
|
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
|
|||
|
"display(f\"real: {real}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Подбор гиперпараметров методом поиска по сетке\n",
|
|||
|
"\n",
|
|||
|
"https://www.kaggle.com/code/sociopath00/random-forest-using-gridsearchcv\n",
|
|||
|
"\n",
|
|||
|
"https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\user\\Projects\\python\\ckmai\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
|||
|
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'model__criterion': 'gini',\n",
|
|||
|
" 'model__max_depth': 7,\n",
|
|||
|
" 'model__max_features': 'sqrt',\n",
|
|||
|
" 'model__n_estimators': 30}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import GridSearchCV\n",
|
|||
|
"\n",
|
|||
|
"optimized_model_type = \"random_forest\"\n",
|
|||
|
"\n",
|
|||
|
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"# \"log_loss\" is not searched: for scikit-learn trees it is an alias of\n",
|
|||
|
"# \"entropy\", so it only repeated identical fits (a third of the grid).\n",
|
|||
|
"param_grid = {\n",
|
|||
|
"    \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
|
|||
|
"    \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
|
|||
|
"    \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
|
|||
|
"    \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Default 5-fold CV over every combination; with the default refit=True the\n",
|
|||
|
"# best pipeline is retrained on all of X_train and exposed as best_estimator_.\n",
|
|||
|
"gs_optomizer = GridSearchCV(\n",
|
|||
|
"    estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
|
|||
|
")\n",
|
|||
|
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"gs_optomizer.best_params_"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение модели с новыми гиперпараметрами"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# GridSearchCV (refit=True by default) has already refitted best_estimator_\n",
|
|||
|
"# on the whole X_train, so calling .fit() again would only repeat that work.\n",
|
|||
|
"pipeline = gs_optomizer.best_estimator_\n",
|
|||
|
"\n",
|
|||
|
"result = run_classification(pipeline, X_train, X_test, y_train, y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование данных для оценки старой и новой версии модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Row \"Old\": metrics stored for the original model; row \"New\": metrics of the\n",
|
|||
|
"# grid-searched pipeline. The columns are the keys of `result`.\n",
|
|||
|
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
|
|||
|
"for metrics_source in (class_models[optimized_model_type], result):\n",
|
|||
|
"    optimized_metrics.loc[len(optimized_metrics)] = pd.Series(data=metrics_source)\n",
|
|||
|
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
|
|||
|
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Оценка метрик качества старой и новой модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_52c7b_row0_col0, #T_52c7b_row0_col1, #T_52c7b_row0_col2, #T_52c7b_row1_col3 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_52c7b_row0_col3, #T_52c7b_row1_col0, #T_52c7b_row1_col1, #T_52c7b_row1_col2 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_52c7b_row0_col4, #T_52c7b_row0_col6, #T_52c7b_row1_col7 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_52c7b_row0_col5, #T_52c7b_row1_col5 {\n",
|
|||
|
" background-color: #0d0887;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_52c7b_row0_col7, #T_52c7b_row1_col4, #T_52c7b_row1_col6 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_52c7b\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_52c7b_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_52c7b_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_52c7b_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_52c7b_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_52c7b_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_52c7b_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_52c7b_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_52c7b_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" <th class=\"blank col5\" > </th>\n",
|
|||
|
" <th class=\"blank col6\" > </th>\n",
|
|||
|
" <th class=\"blank col7\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_52c7b_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_52c7b_row0_col0\" class=\"data row0 col0\" >0.894340</td>\n",
|
|||
|
" <td id=\"T_52c7b_row0_col1\" class=\"data row0 col1\" >0.794118</td>\n",
|
|||
|
" <td id=\"T_52c7b_row0_col2\" class=\"data row0 col2\" >0.868132</td>\n",
|
|||
|
" <td id=\"T_52c7b_row0_col3\" class=\"data row0 col3\" >0.782609</td>\n",
|
|||
|
" <td id=\"T_52c7b_row0_col4\" class=\"data row0 col4\" >0.910112</td>\n",
|
|||
|
" <td id=\"T_52c7b_row0_col5\" class=\"data row0 col5\" >0.837989</td>\n",
|
|||
|
" <td id=\"T_52c7b_row0_col6\" class=\"data row0 col6\" >0.881041</td>\n",
|
|||
|
" <td id=\"T_52c7b_row0_col7\" class=\"data row0 col7\" >0.788321</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_52c7b_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_52c7b_row1_col0\" class=\"data row1 col0\" >0.800699</td>\n",
|
|||
|
" <td id=\"T_52c7b_row1_col1\" class=\"data row1 col1\" >0.777778</td>\n",
|
|||
|
" <td id=\"T_52c7b_row1_col2\" class=\"data row1 col2\" >0.838828</td>\n",
|
|||
|
" <td id=\"T_52c7b_row1_col3\" class=\"data row1 col3\" >0.811594</td>\n",
|
|||
|
" <td id=\"T_52c7b_row1_col4\" class=\"data row1 col4\" >0.858146</td>\n",
|
|||
|
" <td id=\"T_52c7b_row1_col5\" class=\"data row1 col5\" >0.837989</td>\n",
|
|||
|
" <td id=\"T_52c7b_row1_col6\" class=\"data row1 col6\" >0.819320</td>\n",
|
|||
|
" <td id=\"T_52c7b_row1_col7\" class=\"data row1 col7\" >0.794326</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x28221285d30>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 16,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Accuracy / F1 columns get the plasma gradient\n",
|
|||
|
"score_columns = [\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"]\n",
|
|||
|
"# Precision / recall columns get the viridis gradient (low/high swapped)\n",
|
|||
|
"pr_columns = [\"Precision_train\", \"Precision_test\", \"Recall_train\", \"Recall_test\"]\n",
|
|||
|
"\n",
|
|||
|
"styled_metrics = optimized_metrics[pr_columns + score_columns].style\n",
|
|||
|
"styled_metrics = styled_metrics.background_gradient(\n",
|
|||
|
"    cmap=\"plasma\", low=0.3, high=1, subset=score_columns\n",
|
|||
|
")\n",
|
|||
|
"styled_metrics.background_gradient(\n",
|
|||
|
"    cmap=\"viridis\", low=1, high=0.3, subset=pr_columns\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 17,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_3ee0b_row0_col0, #T_3ee0b_row1_col0 {\n",
|
|||
|
" background-color: #440154;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3ee0b_row0_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3ee0b_row0_col2, #T_3ee0b_row0_col3, #T_3ee0b_row0_col4 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3ee0b_row1_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3ee0b_row1_col2, #T_3ee0b_row1_col3, #T_3ee0b_row1_col4 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_3ee0b\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_3ee0b_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_3ee0b_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_3ee0b_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_3ee0b_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_3ee0b_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_3ee0b_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_3ee0b_row0_col0\" class=\"data row0 col0\" >0.837989</td>\n",
|
|||
|
" <td id=\"T_3ee0b_row0_col1\" class=\"data row0 col1\" >0.788321</td>\n",
|
|||
|
" <td id=\"T_3ee0b_row0_col2\" class=\"data row0 col2\" >0.858893</td>\n",
|
|||
|
" <td id=\"T_3ee0b_row0_col3\" class=\"data row0 col3\" >0.657111</td>\n",
|
|||
|
" <td id=\"T_3ee0b_row0_col4\" class=\"data row0 col4\" >0.657157</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_3ee0b_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_3ee0b_row1_col0\" class=\"data row1 col0\" >0.837989</td>\n",
|
|||
|
" <td id=\"T_3ee0b_row1_col1\" class=\"data row1 col1\" >0.794326</td>\n",
|
|||
|
" <td id=\"T_3ee0b_row1_col2\" class=\"data row1 col2\" >0.866140</td>\n",
|
|||
|
" <td id=\"T_3ee0b_row1_col3\" class=\"data row1 col3\" >0.660785</td>\n",
|
|||
|
" <td id=\"T_3ee0b_row1_col4\" class=\"data row1 col4\" >0.661193</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x28220ffbc20>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 17,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# ROC AUC, MCC and Cohen's kappa get the plasma gradient\n",
|
|||
|
"agreement_columns = [\"ROC_AUC_test\", \"MCC_test\", \"Cohen_kappa_test\"]\n",
|
|||
|
"# Accuracy and F1 on the test set get the viridis gradient (low/high swapped)\n",
|
|||
|
"test_score_columns = [\"Accuracy_test\", \"F1_test\"]\n",
|
|||
|
"\n",
|
|||
|
"(\n",
|
|||
|
"    optimized_metrics[\n",
|
|||
|
"        [\"Accuracy_test\", \"F1_test\", \"ROC_AUC_test\", \"Cohen_kappa_test\", \"MCC_test\"]\n",
|
|||
|
"    ]\n",
|
|||
|
"    .style.background_gradient(\n",
|
|||
|
"        cmap=\"plasma\", low=0.3, high=1, subset=agreement_columns\n",
|
|||
|
"    )\n",
|
|||
|
"    .background_gradient(\n",
|
|||
|
"        cmap=\"viridis\", low=1, high=0.3, subset=test_score_columns\n",
|
|||
|
"    )\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 18,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA20AAAGjCAYAAAC/j/0nAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/GU6VOAAAACXBIWXMAAA9hAAAPYQGoP6dpAABW9ElEQVR4nO3dd3hUVf7H8c+EVEgyAYQUCBCkBWkCihHFBUOxrCARlY0rVXcVkaYIClIVRV35oQiuIkWJKCCsDVhFQVBAiYKoiICUQAoqJiFACjPz+4NlZKRlyE1muPf9ep77aGbunDkTYz4593zPuTaXy+USAAAAAMAvBfi6AwAAAACAs2PQBgAAAAB+jEEbAAAAAPgxBm0AAAAA4McYtAEAAACAH2PQBgAAAAB+jEEbAAAAAPixQF93AABQsQoLC1VcXGxYe8HBwQoNDTWsPQAAvGGFXGPQBgAWUlhYqIS64co+6DCszZiYGO3evdvvAg4AYH5WyTUGbQBgIcXFxco+6NDe9HqKjCh7hXz+Yafqttmj4uJivwo3AIA1WCXXGLQBgAWFR9gUHmErcztOlb0NAADKyuy5xqANACzI4XLK4TKmHQAAfM3sucbukQAAAADgx5hpAwALcsolp8p+SdKINgAAKCuz5xqDNgCwIKecMqIAxJhWAAAoG7PnGuWRAAAAAODHmGkDAAtyuFxyuMpeAmJEGwAAlJXZc41BGwBYkNlr/wEA1mL2XKM8EgAAAAD8GDNtAGBBTrnkMPEVSQCAtZg915hpAwAAAAA/xkwbAFiQ2Wv/AQDWYvZcY9AGABZk9l22AADWYvZcozwSAAAAAPwYgzYAsCCngQcAAL7mq1w7fPiwhg4dqrp16yosLExXX321vvrqK/fzLpdLjz/+uGJjYxUWFqbk5GTt2LHD68/HoA0ALMjxv122jDgAAPA1X+XawIED9dFHH+n111/X1q1b1aVLFyUnJ+vAgQOSpKlTp2r69OmaNWuWNm7cqCpVqqhr164qLCz06n0YtAEAAACAl44dO6YlS5Zo6tSp6tChgxo0aKDx48erQYMGmjlzplwul6ZNm6YxY8aoe/fuatGihebPn6/MzEwtW7bMq/di0AYAFuRwGXcAAOBrRudafn6+x1FUVHTaex4/flwOh0OhoaEej4eFhWndunXavXu3srOzlZyc7H7ObrerXbt2Wr9+vVefj0EbAFgQa9oAAGZidK7Fx8fLbre7jylTppz2nhEREUpKStKkSZOUmZkph8OhN954Q+vXr1dWVpays7MlSdHR0R6vi46Odj9XWmz5DwAAAACnyMjIUGRkpPvrkJCQM573+uuvq3///qpVq5YqVaqk1q1bq3fv3kpPTze0P8y0AYAFOWWTw4DDKZuvPwoAAIbnWmRkpMdxtkHbpZdeqjVr1qigoEAZGRn68ssvVVJSovr16ysmJkaSlJOT4/GanJwc93OlxaANACzI6TLuAADA13yda1WqVFFsbKx+//13rVy5Ut27d1dCQoJiYmK0atUq93n5+fnauHGjkpKSvGqf8kgAAAAAuAArV66Uy+VS48aNtXPnTj388MNq0qSJ+vXrJ5vNpqFDh2ry5Mlq2LChEhISNHbsWMXFxalHjx5evQ+DNgCwoJNlIEa0AwCAr/kq1/Ly8jR69Gjt379f1apVU0pKip544gkFBQVJkkaOHKkjR47o3nvvVW5urq655hqtWLHitB0nz4fySABAhTl8+LCGDh2qunXrKiwsTFdffbW++uor9/Mul0uPP/64YmNjFRYWpuTkZO3YscOHPQYA4Oxuv/127dq1S0VFRcrKytKLL74ou93uft5ms2nixInKzs5WYWGhPv74YzVq1Mjr92HQBgAWZMRi7Qu5qjlw4EB99NFHev3117V161Z16dJFycnJOnDggCRp6tSpmj59umbNmqWNGzeqSpUq6tq1qwoLC8vj2wAAMAlf5VpFYdAGABbkdNkMO6TS3Y
T02LFjWrJkiaZOnaoOHTqoQYMGGj9+vBo0aKCZM2fK5XJp2rRpGjNmjLp3764WLVpo/vz5yszM1LJlyyr4OwQAuJgYnWv+hkEbAKDMSnMT0uPHj8vhcJxWxx8WFqZ169Zp9+7dys7OVnJysvs5u92udu3aaf369eX+GQAA8FdsRAIAFmT0gu3S3IQ0IiJCSUlJmjRpkhITExUdHa0333xT69evV4MGDZSdnS1Jio6O9nhddHS0+zkAAM7E7BtsMWgDAAtyKEAOA4otHP/758mbj57P66+/rv79+6tWrVqqVKmSWrdurd69eys9Pb3MfQEAWJfRueZvKI8EAFSYSy+9VGvWrFFBQYEyMjL05ZdfqqSkRPXr11dMTIwkKScnx+M1OTk57ucAALAiBm0AYEEugxZruy5wwXaVKlUUGxur33//XStXrlT37t2VkJCgmJgYrVq1yn1efn6+Nm7cqKSkJKM+OgDAhHyda+WN8kgAsCBf1f6vXLlSLpdLjRs31s6dO/Xwww+rSZMm6tevn2w2m4YOHarJkyerYcOGSkhI0NixYxUXF6cePXqUua8AAPNiTRsAAAbJy8vT6NGjtX//flWrVk0pKSl64oknFBQUJEkaOXKkjhw5onvvvVe5ubm65pprtGLFitN2nAQAwEpsLpfL5etOAAAqRn5+vux2u5Z/m6AqEWWvkD9y2KkbWuxWXl5eqTYiAQDASFbJNda0AQAAAIAfozwSACzIKZucBly3c4piDQCA75k91xi0AYAFmX3BNgDAWsyea5RHAgAAAIAfY6YNACzI4QqQw1X263YO9rICAPgBs+cagzYAsKATtf9lLwExog0AAMrK7LlGeSQAAAAA+DFm2gDAgpwKkMPEu2wBAKzF7LnGoA0ALMjstf8AAGsxe65RHgkAAAAAfoyZNgCwIKcCTH0TUgCAtZg91xi0AYAFOVw2OVwG3ITUgDYAACgrs+ca5ZEAAAAA4MeYaQMAC3IYtMuWw0/LSAAA1mL2XGOmDQAAAAD8GDNtAGBBTleAnAZsjez0062RAQDWYvZcY9AGABZk9jISAIC1mD3XKI8EAAAAAD/GTBsAWJBTxmxr7Cx7VwAAKDOz5xqDNgCwIONuQkrBBgDA98yea/7ZKwAAAACAJGbaAMCSHK4AOQzYZcuINgAAKCuz5xqDNgCwIKdscsqI2v+ytwEAQFmZPdf8cygJAAAAAJDETBsAWJLZy0gAANZi9lzzz14BAAAAACQx0wYAluRQgBwGXLczog0AAMrK7LnGoM1HnE6nMjMzFRERIZvNPxc8AvAvLpdLhw8fVlxcnAICyhYqTpdNTiNuQmpAGzAHcg2At8i10mPQ5iOZmZmKj4/3dTcAXIQyMjJUu3ZtX3cD8ECuAbhQ5Nr5MWjzkYiICEnS3q/rKTLcP6dh4Tu3Nmru6y7ADx1XidbpQ/fvj7JwGlRG4vTTMhJUPHIN55LSJsnXXYAfOu4q0WcFb1/UueZwODR+/Hi98cYbys7OVlxcnPr27asxY8a4qw5cLpfGjRunV155Rbm5uWrfvr1mzpyphg0blvp9GLT5yMn/iJHhAYqMINzgKdAW5OsuwB+5TvzDiNIzpytATgN2yDKiDZgDuYZzCbQF+7oL8GMXc649/fTTmjlzpubNm6fLLrtMmzZtUr9+/WS32/Xggw9KkqZOnarp06dr3rx5SkhI0NixY9W1a1f98MMPCg0NLdX7MGgDAAAAgFPk5+d7fB0SEqKQkJDTzvviiy/UvXt33XTTTZKkevXq6c0339SXX34p6cQs27Rp0zRmzBh1795dkjR//nxFR0dr2bJluvPOO0vVHy6FAYAFOWQz7AAAwNeMzrX4+HjZ7Xb3MWXKlDO+79VXX61Vq1bpp59+kiRt2bJF69at0w033CBJ2r17t7Kzs5WcnOx+jd1uV7t27bR+/fpSfz5m2gDAgiiPBACYidG5lpGRocjISPfjZ5plk6RRo0YpPz9fTZo0UaVKle
RwOPTEE08oNTVVkpSdnS1Jio6O9nhddHS0+7nSYNAGAAAAAKeIjIz0GLSdzdtvv60FCxYoLS1Nl112mTZv3qyhQ4c
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x400 with 4 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"# One confusion matrix per model version (rows of optimized_metrics: Old, New)\n",
|
|||
|
"for panel, (version, scores) in zip(ax.flat, optimized_metrics.iterrows()):\n",
|
|||
|
"    ConfusionMatrixDisplay(\n",
|
|||
|
"        confusion_matrix=scores[\"Confusion_matrix\"],\n",
|
|||
|
"        display_labels=[\"Died\", \"Survived\"],\n",
|
|||
|
"    ).plot(ax=panel)\n",
|
|||
|
"    panel.set_title(version)\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Регрессия"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Загрузка данных"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 19,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>T</th>\n",
|
|||
|
" <th>Al2O3</th>\n",
|
|||
|
" <th>TiO2</th>\n",
|
|||
|
" <th>Density</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>20</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.06250</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.05979</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>35</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.05404</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" T Al2O3 TiO2 Density\n",
|
|||
|
"0 20 0.0 0.0 1.06250\n",
|
|||
|
"1 25 0.0 0.0 1.05979\n",
|
|||
|
"2 35 0.0 0.0 1.05404"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>T</th>\n",
|
|||
|
" <th>Al2O3</th>\n",
|
|||
|
" <th>TiO2</th>\n",
|
|||
|
" <th>Density</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>30</td>\n",
|
|||
|
" <td>0.00</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.05696</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>55</td>\n",
|
|||
|
" <td>0.00</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.04158</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>0.05</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.08438</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" T Al2O3 TiO2 Density\n",
|
|||
|
"0 30 0.00 0.0 1.05696\n",
|
|||
|
"1 55 0.00 0.0 1.04158\n",
|
|||
|
"2 25 0.05 0.0 1.08438"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"# Both files use ';' as the field separator and ',' as the decimal mark\n",
|
|||
|
"csv_options = {\"sep\": \";\", \"decimal\": \",\"}\n",
|
|||
|
"density_train = pd.read_csv(\"data/density/density_train.csv\", **csv_options)\n",
|
|||
|
"density_test = pd.read_csv(\"data/density/density_test.csv\", **csv_options)\n",
|
|||
|
"\n",
|
|||
|
"display(density_train.head(3))\n",
|
|||
|
"display(density_test.head(3))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование выборок"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 20,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>T</th>\n",
|
|||
|
" <th>Al2O3</th>\n",
|
|||
|
" <th>TiO2</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>20</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>35</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" T Al2O3 TiO2\n",
|
|||
|
"0 20 0.0 0.0\n",
|
|||
|
"1 25 0.0 0.0\n",
|
|||
|
"2 35 0.0 0.0"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Density</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>1.06250</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>1.05979</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>1.05404</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Density\n",
|
|||
|
"0 1.06250\n",
|
|||
|
"1 1.05979\n",
|
|||
|
"2 1.05404"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>T</th>\n",
|
|||
|
" <th>Al2O3</th>\n",
|
|||
|
" <th>TiO2</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>30</td>\n",
|
|||
|
" <td>0.00</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>55</td>\n",
|
|||
|
" <td>0.00</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>0.05</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" T Al2O3 TiO2\n",
|
|||
|
"0 30 0.00 0.0\n",
|
|||
|
"1 55 0.00 0.0\n",
|
|||
|
"2 25 0.05 0.0"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Density</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>1.05696</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>1.04158</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>1.08438</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Density\n",
|
|||
|
"0 1.05696\n",
|
|||
|
"1 1.04158\n",
|
|||
|
"2 1.08438"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Split off the \"Density\" target. NOTE: density_train / density_test are\n",
|
|||
|
"# overwritten with the feature-only frames, so re-running this cell alone\n",
|
|||
|
"# fails (the column is already gone); re-run the loading cell first.\n",
|
|||
|
"density_y_train = density_train[[\"Density\"]]\n",
|
|||
|
"density_train = density_train.drop(columns=[\"Density\"])\n",
|
|||
|
"\n",
|
|||
|
"display(density_train.head(3))\n",
|
|||
|
"display(density_y_train.head(3))\n",
|
|||
|
"\n",
|
|||
|
"density_y_test = density_test[[\"Density\"]]\n",
|
|||
|
"density_test = density_test.drop(columns=[\"Density\"])\n",
|
|||
|
"\n",
|
|||
|
"display(density_test.head(3))\n",
|
|||
|
"display(density_y_test.head(3))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 21,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.pipeline import make_pipeline\n",
|
|||
|
"from sklearn.preprocessing import PolynomialFeatures\n",
|
|||
|
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
|
|||
|
"\n",
|
|||
|
"random_state = 9\n",
|
|||
|
"\n",
|
|||
|
"models = {\n",
|
|||
|
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
|
|||
|
" \"linear_poly\": {\n",
|
|||
|
" \"model\": make_pipeline(\n",
|
|||
|
" PolynomialFeatures(degree=2),\n",
|
|||
|
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"linear_interact\": {\n",
|
|||
|
" \"model\": make_pipeline(\n",
|
|||
|
" PolynomialFeatures(interaction_only=True),\n",
|
|||
|
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
|
|||
|
" \"decision_tree\": {\n",
|
|||
|
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
|
|||
|
" },\n",
|
|||
|
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
|
|||
|
" \"random_forest\": {\n",
|
|||
|
" \"model\": ensemble.RandomForestRegressor(\n",
|
|||
|
" max_depth=7, random_state=random_state, n_jobs=-1\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"mlp\": {\n",
|
|||
|
" \"model\": neural_network.MLPRegressor(\n",
|
|||
|
" activation=\"tanh\",\n",
|
|||
|
" hidden_layer_sizes=(3,),\n",
|
|||
|
" max_iter=500,\n",
|
|||
|
" early_stopping=True,\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
"}"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Обучение и оценка моделей с помощью различных алгоритмов"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 22,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: linear\n",
|
|||
|
"Model: linear_poly\n",
|
|||
|
"Model: linear_interact\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"Model: knn\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from src.utils import run_regression\n",
|
|||
|
"\n",
|
|||
|
"for model_name in models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" X_train = density_train\n",
|
|||
|
" X_test = density_test\n",
|
|||
|
" y_train = density_y_train\n",
|
|||
|
" y_test = density_y_test\n",
|
|||
|
"\n",
|
|||
|
" model = models[model_name][\"model\"]\n",
|
|||
|
" fitted_model = model.fit(\n",
|
|||
|
" X_train.values, density_y_train.values.ravel()\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
" models[model_name] = run_regression(fitted_model, X_train, X_test, y_train, y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Вывод результатов оценки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 23,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_63a66_row0_col0, #T_63a66_row0_col1, #T_63a66_row4_col0 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row0_col2, #T_63a66_row7_col3 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row0_col3, #T_63a66_row1_col3, #T_63a66_row2_col3, #T_63a66_row3_col3, #T_63a66_row4_col3, #T_63a66_row7_col2 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row1_col0, #T_63a66_row1_col1, #T_63a66_row2_col0 {\n",
|
|||
|
" background-color: #26828e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row1_col2 {\n",
|
|||
|
" background-color: #5901a5;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row2_col1, #T_63a66_row3_col0 {\n",
|
|||
|
" background-color: #25838e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row2_col2 {\n",
|
|||
|
" background-color: #6400a7;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row3_col1 {\n",
|
|||
|
" background-color: #24868e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row3_col2 {\n",
|
|||
|
" background-color: #7100a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row4_col1 {\n",
|
|||
|
" background-color: #24878e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row4_col2 {\n",
|
|||
|
" background-color: #7701a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row5_col0 {\n",
|
|||
|
" background-color: #21908d;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row5_col1 {\n",
|
|||
|
" background-color: #21918c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row5_col2 {\n",
|
|||
|
" background-color: #910ea3;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row5_col3 {\n",
|
|||
|
" background-color: #d8576b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row6_col0 {\n",
|
|||
|
" background-color: #38b977;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row6_col1 {\n",
|
|||
|
" background-color: #3bbb75;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row6_col2 {\n",
|
|||
|
" background-color: #c5407e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row6_col3 {\n",
|
|||
|
" background-color: #b7318a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_63a66_row7_col0, #T_63a66_row7_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_63a66\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_63a66_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
|
|||
|
" <th id=\"T_63a66_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
|
|||
|
" <th id=\"T_63a66_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
|
|||
|
" <th id=\"T_63a66_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_63a66_level0_row0\" class=\"row_heading level0 row0\" >linear_poly</th>\n",
|
|||
|
" <td id=\"T_63a66_row0_col0\" class=\"data row0 col0\" >0.000319</td>\n",
|
|||
|
" <td id=\"T_63a66_row0_col1\" class=\"data row0 col1\" >0.000362</td>\n",
|
|||
|
" <td id=\"T_63a66_row0_col2\" class=\"data row0 col2\" >0.016643</td>\n",
|
|||
|
" <td id=\"T_63a66_row0_col3\" class=\"data row0 col3\" >0.999965</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_63a66_level0_row1\" class=\"row_heading level0 row1\" >linear_interact</th>\n",
|
|||
|
" <td id=\"T_63a66_row1_col0\" class=\"data row1 col0\" >0.001131</td>\n",
|
|||
|
" <td id=\"T_63a66_row1_col1\" class=\"data row1 col1\" >0.001491</td>\n",
|
|||
|
" <td id=\"T_63a66_row1_col2\" class=\"data row1 col2\" >0.033198</td>\n",
|
|||
|
" <td id=\"T_63a66_row1_col3\" class=\"data row1 col3\" >0.999413</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_63a66_level0_row2\" class=\"row_heading level0 row2\" >linear</th>\n",
|
|||
|
" <td id=\"T_63a66_row2_col0\" class=\"data row2 col0\" >0.002464</td>\n",
|
|||
|
" <td id=\"T_63a66_row2_col1\" class=\"data row2 col1\" >0.003261</td>\n",
|
|||
|
" <td id=\"T_63a66_row2_col2\" class=\"data row2 col2\" >0.049891</td>\n",
|
|||
|
" <td id=\"T_63a66_row2_col3\" class=\"data row2 col3\" >0.997191</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_63a66_level0_row3\" class=\"row_heading level0 row3\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_63a66_row3_col0\" class=\"data row3 col0\" >0.002716</td>\n",
|
|||
|
" <td id=\"T_63a66_row3_col1\" class=\"data row3 col1\" >0.005575</td>\n",
|
|||
|
" <td id=\"T_63a66_row3_col2\" class=\"data row3 col2\" >0.067298</td>\n",
|
|||
|
" <td id=\"T_63a66_row3_col3\" class=\"data row3 col3\" >0.991788</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_63a66_level0_row4\" class=\"row_heading level0 row4\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_63a66_row4_col0\" class=\"data row4 col0\" >0.000346</td>\n",
|
|||
|
" <td id=\"T_63a66_row4_col1\" class=\"data row4 col1\" >0.006433</td>\n",
|
|||
|
" <td id=\"T_63a66_row4_col2\" class=\"data row4 col2\" >0.076138</td>\n",
|
|||
|
" <td id=\"T_63a66_row4_col3\" class=\"data row4 col3\" >0.989067</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_63a66_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
|
|||
|
" <td id=\"T_63a66_row5_col0\" class=\"data row5 col0\" >0.013989</td>\n",
|
|||
|
" <td id=\"T_63a66_row5_col1\" class=\"data row5 col1\" >0.015356</td>\n",
|
|||
|
" <td id=\"T_63a66_row5_col2\" class=\"data row5 col2\" >0.116380</td>\n",
|
|||
|
" <td id=\"T_63a66_row5_col3\" class=\"data row5 col3\" >0.937703</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_63a66_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
|
|||
|
" <td id=\"T_63a66_row6_col0\" class=\"data row6 col0\" >0.053108</td>\n",
|
|||
|
" <td id=\"T_63a66_row6_col1\" class=\"data row6 col1\" >0.056776</td>\n",
|
|||
|
" <td id=\"T_63a66_row6_col2\" class=\"data row6 col2\" >0.217611</td>\n",
|
|||
|
" <td id=\"T_63a66_row6_col3\" class=\"data row6 col3\" >0.148414</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_63a66_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
|
|||
|
" <td id=\"T_63a66_row7_col0\" class=\"data row7 col0\" >0.095734</td>\n",
|
|||
|
" <td id=\"T_63a66_row7_col1\" class=\"data row7 col1\" >0.099654</td>\n",
|
|||
|
" <td id=\"T_63a66_row7_col2\" class=\"data row7 col2\" >0.270371</td>\n",
|
|||
|
" <td id=\"T_63a66_row7_col3\" class=\"data row7 col3\" >-1.623554</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x282215a87a0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 23,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
|
|||
|
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
|
|||
|
"]\n",
|
|||
|
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
|
|||
|
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
|
|||
|
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Получение лучшей модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 24,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'linear_poly'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
|
|||
|
"\n",
|
|||
|
"display(best_model)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Вывод для обучающей выборки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 25,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>T</th>\n",
|
|||
|
" <th>Al2O3</th>\n",
|
|||
|
" <th>TiO2</th>\n",
|
|||
|
" <th>Density</th>\n",
|
|||
|
" <th>DensityPred</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>20</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.06250</td>\n",
|
|||
|
" <td>1.063174</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.05979</td>\n",
|
|||
|
" <td>1.060117</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>35</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.05404</td>\n",
|
|||
|
" <td>1.053941</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>40</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.05103</td>\n",
|
|||
|
" <td>1.050822</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>45</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.04794</td>\n",
|
|||
|
" <td>1.047683</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" T Al2O3 TiO2 Density DensityPred\n",
|
|||
|
"0 20 0.0 0.0 1.06250 1.063174\n",
|
|||
|
"1 25 0.0 0.0 1.05979 1.060117\n",
|
|||
|
"2 35 0.0 0.0 1.05404 1.053941\n",
|
|||
|
"3 40 0.0 0.0 1.05103 1.050822\n",
|
|||
|
"4 45 0.0 0.0 1.04794 1.047683"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 25,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"pd.concat(\n",
|
|||
|
" [\n",
|
|||
|
" density_train,\n",
|
|||
|
" density_y_train,\n",
|
|||
|
" pd.Series(\n",
|
|||
|
" models[best_model][\"train_preds\"],\n",
|
|||
|
" index=density_y_train.index,\n",
|
|||
|
" name=\"DensityPred\",\n",
|
|||
|
" ),\n",
|
|||
|
" ],\n",
|
|||
|
" axis=1,\n",
|
|||
|
").head(5)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Вывод для тестовой выборки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 26,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>T</th>\n",
|
|||
|
" <th>Al2O3</th>\n",
|
|||
|
" <th>TiO2</th>\n",
|
|||
|
" <th>Density</th>\n",
|
|||
|
" <th>DensityPred</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>30</td>\n",
|
|||
|
" <td>0.00</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.05696</td>\n",
|
|||
|
" <td>1.057040</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>55</td>\n",
|
|||
|
" <td>0.00</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.04158</td>\n",
|
|||
|
" <td>1.041341</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>0.05</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.08438</td>\n",
|
|||
|
" <td>1.084063</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>30</td>\n",
|
|||
|
" <td>0.05</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.08112</td>\n",
|
|||
|
" <td>1.080764</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>35</td>\n",
|
|||
|
" <td>0.05</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.07781</td>\n",
|
|||
|
" <td>1.077444</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" T Al2O3 TiO2 Density DensityPred\n",
|
|||
|
"0 30 0.00 0.0 1.05696 1.057040\n",
|
|||
|
"1 55 0.00 0.0 1.04158 1.041341\n",
|
|||
|
"2 25 0.05 0.0 1.08438 1.084063\n",
|
|||
|
"3 30 0.05 0.0 1.08112 1.080764\n",
|
|||
|
"4 35 0.05 0.0 1.07781 1.077444"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 26,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"pd.concat(\n",
|
|||
|
" [\n",
|
|||
|
" density_test,\n",
|
|||
|
" density_y_test,\n",
|
|||
|
" pd.Series(\n",
|
|||
|
" models[best_model][\"preds\"],\n",
|
|||
|
" index=density_y_test.index,\n",
|
|||
|
" name=\"DensityPred\",\n",
|
|||
|
" ),\n",
|
|||
|
" ],\n",
|
|||
|
" axis=1,\n",
|
|||
|
").head(5)"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": ".venv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.12.7"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|