pred_analytics/lec4.ipynb
2025-01-13 14:42:39 +04:00

4368 lines
295 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Классификация"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Загрузка набора данных"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Name</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Ticket</th>\n",
" <th>Fare</th>\n",
" <th>Cabin</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" <tr>\n",
" <th>PassengerId</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Braund, Mr. Owen Harris</td>\n",
" <td>male</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>A/5 21171</td>\n",
" <td>7.2500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
" <td>female</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>PC 17599</td>\n",
" <td>71.2833</td>\n",
" <td>C85</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Heikkinen, Miss. Laina</td>\n",
" <td>female</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>STON/O2. 3101282</td>\n",
" <td>7.9250</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
" <td>female</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>113803</td>\n",
" <td>53.1000</td>\n",
" <td>C123</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Allen, Mr. William Henry</td>\n",
" <td>male</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>373450</td>\n",
" <td>8.0500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>887</th>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>Montvila, Rev. Juozas</td>\n",
" <td>male</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>211536</td>\n",
" <td>13.0000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>888</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Graham, Miss. Margaret Edith</td>\n",
" <td>female</td>\n",
" <td>19.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>112053</td>\n",
" <td>30.0000</td>\n",
" <td>B42</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>889</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Johnston, Miss. Catherine Helen \"Carrie\"</td>\n",
" <td>female</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>W./C. 6607</td>\n",
" <td>23.4500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>890</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Behr, Mr. Karl Howell</td>\n",
" <td>male</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>111369</td>\n",
" <td>30.0000</td>\n",
" <td>C148</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>891</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Dooley, Mr. Patrick</td>\n",
" <td>male</td>\n",
" <td>32.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>370376</td>\n",
" <td>7.7500</td>\n",
" <td>NaN</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>891 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" Survived Pclass \\\n",
"PassengerId \n",
"1 0 3 \n",
"2 1 1 \n",
"3 1 3 \n",
"4 1 1 \n",
"5 0 3 \n",
"... ... ... \n",
"887 0 2 \n",
"888 1 1 \n",
"889 0 3 \n",
"890 1 1 \n",
"891 0 3 \n",
"\n",
" Name Sex Age \\\n",
"PassengerId \n",
"1 Braund, Mr. Owen Harris male 22.0 \n",
"2 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 \n",
"3 Heikkinen, Miss. Laina female 26.0 \n",
"4 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 \n",
"5 Allen, Mr. William Henry male 35.0 \n",
"... ... ... ... \n",
"887 Montvila, Rev. Juozas male 27.0 \n",
"888 Graham, Miss. Margaret Edith female 19.0 \n",
"889 Johnston, Miss. Catherine Helen \"Carrie\" female NaN \n",
"890 Behr, Mr. Karl Howell male 26.0 \n",
"891 Dooley, Mr. Patrick male 32.0 \n",
"\n",
" SibSp Parch Ticket Fare Cabin Embarked \n",
"PassengerId \n",
"1 1 0 A/5 21171 7.2500 NaN S \n",
"2 1 0 PC 17599 71.2833 C85 C \n",
"3 0 0 STON/O2. 3101282 7.9250 NaN S \n",
"4 1 0 113803 53.1000 C123 S \n",
"5 0 0 373450 8.0500 NaN S \n",
"... ... ... ... ... ... ... \n",
"887 0 0 211536 13.0000 NaN S \n",
"888 0 0 112053 30.0000 B42 S \n",
"889 1 2 W./C. 6607 23.4500 NaN S \n",
"890 0 0 111369 30.0000 C148 C \n",
"891 0 0 370376 7.7500 NaN Q \n",
"\n",
"[891 rows x 11 columns]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"random_state=9\n",
"\n",
"df = pd.read_csv(\"data/titanic.csv\", index_col=\"PassengerId\")\n",
"\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовую выборки (80/20) для задачи классификации\n",
"\n",
"Целевой признак -- Survived"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Name</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Ticket</th>\n",
" <th>Fare</th>\n",
" <th>Cabin</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" <tr>\n",
" <th>PassengerId</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>145</th>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>Andrew, Mr. Edgardo Samuel</td>\n",
" <td>male</td>\n",
" <td>18.00</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>231945</td>\n",
" <td>11.5000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>206</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Strom, Miss. Telma Matilda</td>\n",
" <td>female</td>\n",
" <td>2.00</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>347054</td>\n",
" <td>10.4625</td>\n",
" <td>G6</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>349</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Coutts, Master. William Loch \"William\"</td>\n",
" <td>male</td>\n",
" <td>3.00</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>C.A. 37671</td>\n",
" <td>15.9000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>329</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Goldsmith, Mrs. Frank John (Emily Alice Brown)</td>\n",
" <td>female</td>\n",
" <td>31.00</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>363291</td>\n",
" <td>20.5250</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>289</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Hosono, Mr. Masabumi</td>\n",
" <td>male</td>\n",
" <td>42.00</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>237798</td>\n",
" <td>13.0000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>756</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Hamalainen, Master. Viljo</td>\n",
" <td>male</td>\n",
" <td>0.67</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>250649</td>\n",
" <td>14.5000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>816</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Fry, Mr. Richard</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>112058</td>\n",
" <td>0.0000</td>\n",
" <td>B102</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>890</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Behr, Mr. Karl Howell</td>\n",
" <td>male</td>\n",
" <td>26.00</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>111369</td>\n",
" <td>30.0000</td>\n",
" <td>C148</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>738</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Lesurer, Mr. Gustave J</td>\n",
" <td>male</td>\n",
" <td>35.00</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>PC 17755</td>\n",
" <td>512.3292</td>\n",
" <td>B101</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Sirayanian, Mr. Orsen</td>\n",
" <td>male</td>\n",
" <td>22.00</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2669</td>\n",
" <td>7.2292</td>\n",
" <td>NaN</td>\n",
" <td>C</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>712 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" Survived Pclass Name \\\n",
"PassengerId \n",
"145 0 2 Andrew, Mr. Edgardo Samuel \n",
"206 0 3 Strom, Miss. Telma Matilda \n",
"349 1 3 Coutts, Master. William Loch \"William\" \n",
"329 1 3 Goldsmith, Mrs. Frank John (Emily Alice Brown) \n",
"289 1 2 Hosono, Mr. Masabumi \n",
"... ... ... ... \n",
"756 1 2 Hamalainen, Master. Viljo \n",
"816 0 1 Fry, Mr. Richard \n",
"890 1 1 Behr, Mr. Karl Howell \n",
"738 1 1 Lesurer, Mr. Gustave J \n",
"61 0 3 Sirayanian, Mr. Orsen \n",
"\n",
" Sex Age SibSp Parch Ticket Fare Cabin Embarked \n",
"PassengerId \n",
"145 male 18.00 0 0 231945 11.5000 NaN S \n",
"206 female 2.00 0 1 347054 10.4625 G6 S \n",
"349 male 3.00 1 1 C.A. 37671 15.9000 NaN S \n",
"329 female 31.00 1 1 363291 20.5250 NaN S \n",
"289 male 42.00 0 0 237798 13.0000 NaN S \n",
"... ... ... ... ... ... ... ... ... \n",
"756 male 0.67 1 1 250649 14.5000 NaN S \n",
"816 male NaN 0 0 112058 0.0000 B102 S \n",
"890 male 26.00 0 0 111369 30.0000 C148 C \n",
"738 male 35.00 0 0 PC 17755 512.3292 B101 C \n",
"61 male 22.00 0 0 2669 7.2292 NaN C \n",
"\n",
"[712 rows x 11 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" </tr>\n",
" <tr>\n",
" <th>PassengerId</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>145</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>206</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>349</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>329</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>289</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>756</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>816</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>890</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>738</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>712 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Survived\n",
"PassengerId \n",
"145 0\n",
"206 0\n",
"349 1\n",
"329 1\n",
"289 1\n",
"... ...\n",
"756 1\n",
"816 0\n",
"890 1\n",
"738 1\n",
"61 0\n",
"\n",
"[712 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Name</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Ticket</th>\n",
" <th>Fare</th>\n",
" <th>Cabin</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" <tr>\n",
" <th>PassengerId</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>843</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Serepeca, Miss. Augusta</td>\n",
" <td>female</td>\n",
" <td>30.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>113798</td>\n",
" <td>31.0000</td>\n",
" <td>NaN</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>791</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Keane, Mr. Andrew \"Andy\"</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>12460</td>\n",
" <td>7.7500</td>\n",
" <td>NaN</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" <tr>\n",
" <th>509</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Olsen, Mr. Henry Margido</td>\n",
" <td>male</td>\n",
" <td>28.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>C 4001</td>\n",
" <td>22.5250</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>828</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Mallet, Master. Andre</td>\n",
" <td>male</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>S.C./PARIS 2079</td>\n",
" <td>37.0042</td>\n",
" <td>NaN</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>414</th>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>Cunningham, Mr. Alfred Fleming</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>239853</td>\n",
" <td>0.0000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>824</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Moor, Mrs. (Beila)</td>\n",
" <td>female</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>392096</td>\n",
" <td>12.4750</td>\n",
" <td>E121</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>353</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Elias, Mr. Tannous</td>\n",
" <td>male</td>\n",
" <td>15.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2695</td>\n",
" <td>7.2292</td>\n",
" <td>NaN</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>674</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Wilhelms, Mr. Charles</td>\n",
" <td>male</td>\n",
" <td>31.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>244270</td>\n",
" <td>13.0000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>100</th>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>Kantor, Mr. Sinai</td>\n",
" <td>male</td>\n",
" <td>34.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>244367</td>\n",
" <td>26.0000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>542</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Andersson, Miss. Ingeborg Constanzia</td>\n",
" <td>female</td>\n",
" <td>9.0</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>347082</td>\n",
" <td>31.2750</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>179 rows × 11 columns</p>\n",
"</div>"
],
"text/plain": [
" Survived Pclass Name Sex \\\n",
"PassengerId \n",
"843 1 1 Serepeca, Miss. Augusta female \n",
"791 0 3 Keane, Mr. Andrew \"Andy\" male \n",
"509 0 3 Olsen, Mr. Henry Margido male \n",
"828 1 2 Mallet, Master. Andre male \n",
"414 0 2 Cunningham, Mr. Alfred Fleming male \n",
"... ... ... ... ... \n",
"824 1 3 Moor, Mrs. (Beila) female \n",
"353 0 3 Elias, Mr. Tannous male \n",
"674 1 2 Wilhelms, Mr. Charles male \n",
"100 0 2 Kantor, Mr. Sinai male \n",
"542 0 3 Andersson, Miss. Ingeborg Constanzia female \n",
"\n",
" Age SibSp Parch Ticket Fare Cabin Embarked \n",
"PassengerId \n",
"843 30.0 0 0 113798 31.0000 NaN C \n",
"791 NaN 0 0 12460 7.7500 NaN Q \n",
"509 28.0 0 0 C 4001 22.5250 NaN S \n",
"828 1.0 0 2 S.C./PARIS 2079 37.0042 NaN C \n",
"414 NaN 0 0 239853 0.0000 NaN S \n",
"... ... ... ... ... ... ... ... \n",
"824 27.0 0 1 392096 12.4750 E121 S \n",
"353 15.0 1 1 2695 7.2292 NaN C \n",
"674 31.0 0 0 244270 13.0000 NaN S \n",
"100 34.0 1 0 244367 26.0000 NaN S \n",
"542 9.0 4 2 347082 31.2750 NaN S \n",
"\n",
"[179 rows x 11 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" </tr>\n",
" <tr>\n",
" <th>PassengerId</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>843</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>791</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>509</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>828</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>414</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>824</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>353</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>674</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>100</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>542</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>179 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Survived\n",
"PassengerId \n",
"843 1\n",
"791 0\n",
"509 0\n",
"828 1\n",
"414 0\n",
"... ...\n",
"824 1\n",
"353 0\n",
"674 1\n",
"100 0\n",
"542 0\n",
"\n",
"[179 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from src.utils import split_stratified_into_train_val_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"Survived\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование конвейера для классификации данных\n",
"\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"features_engineering -- трансформер для конструирования признаков\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"features_postprocessing -- трансформер для унитарного кодирования новых признаков\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков\n",
"\n",
"Конвейер выполняется последовательно.\n",
"\n",
"Трансформер (ColumnTransformer) применяет свои преобразования параллельно, каждое — к указанному набору колонок.\n",
"\n",
"Документация: \n",
"\n",
"https://scikit-learn.org/1.5/api/sklearn.pipeline.html\n",
"\n",
"https://scikit-learn.org/1.5/modules/generated/sklearn.compose.ColumnTransformer.html#sklearn.compose.ColumnTransformer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"from src.transformers import TitanicFeatures\n",
"\n",
"\n",
"columns_to_drop = [\"Survived\", \"Name\", \"Cabin\", \"Ticket\", \"Embarked\", \"Parch\", \"Fare\"]\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" (\"prepocessing_features\", cat_imputer, [\"Name\", \"Cabin\"]),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"features_engineering = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"add_features\", TitanicFeatures(), [\"Name\", \"Cabin\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"features_engineering\", features_engineering),\n",
" (\"drop_columns\", drop_columns),\n",
" (\"features_postprocessing\", features_postprocessing),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Демонстрация работы конвейера для предобработки данных при классификации"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Cabin_type_B</th>\n",
" <th>Cabin_type_C</th>\n",
" <th>Cabin_type_D</th>\n",
" <th>Cabin_type_E</th>\n",
" <th>Cabin_type_F</th>\n",
" <th>Cabin_type_G</th>\n",
" <th>Cabin_type_T</th>\n",
" <th>Cabin_type_u</th>\n",
" <th>Is_married</th>\n",
" <th>Pclass</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Sex_male</th>\n",
" </tr>\n",
" <tr>\n",
" <th>PassengerId</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>145</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>-0.379423</td>\n",
" <td>-0.869506</td>\n",
" <td>-0.473465</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>206</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0.821241</td>\n",
" <td>-2.102186</td>\n",
" <td>-0.473465</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>349</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0.821241</td>\n",
" <td>-2.025143</td>\n",
" <td>0.437635</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>329</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1</td>\n",
" <td>0.821241</td>\n",
" <td>0.132047</td>\n",
" <td>0.437635</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>289</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>-0.379423</td>\n",
" <td>0.979514</td>\n",
" <td>-0.473465</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>756</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>-0.379423</td>\n",
" <td>-2.204652</td>\n",
" <td>0.437635</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>816</th>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>-1.580088</td>\n",
" <td>-0.099081</td>\n",
" <td>-0.473465</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>890</th>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>-1.580088</td>\n",
" <td>-0.253166</td>\n",
" <td>-0.473465</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>738</th>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>-1.580088</td>\n",
" <td>0.440217</td>\n",
" <td>-0.473465</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0.821241</td>\n",
" <td>-0.561336</td>\n",
" <td>-0.473465</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>712 rows × 13 columns</p>\n",
"</div>"
],
"text/plain": [
" Cabin_type_B Cabin_type_C Cabin_type_D Cabin_type_E \\\n",
"PassengerId \n",
"145 0.0 0.0 0.0 0.0 \n",
"206 0.0 0.0 0.0 0.0 \n",
"349 0.0 0.0 0.0 0.0 \n",
"329 0.0 0.0 0.0 0.0 \n",
"289 0.0 0.0 0.0 0.0 \n",
"... ... ... ... ... \n",
"756 0.0 0.0 0.0 0.0 \n",
"816 1.0 0.0 0.0 0.0 \n",
"890 0.0 1.0 0.0 0.0 \n",
"738 1.0 0.0 0.0 0.0 \n",
"61 0.0 0.0 0.0 0.0 \n",
"\n",
" Cabin_type_F Cabin_type_G Cabin_type_T Cabin_type_u \\\n",
"PassengerId \n",
"145 0.0 0.0 0.0 1.0 \n",
"206 0.0 1.0 0.0 0.0 \n",
"349 0.0 0.0 0.0 1.0 \n",
"329 0.0 0.0 0.0 1.0 \n",
"289 0.0 0.0 0.0 1.0 \n",
"... ... ... ... ... \n",
"756 0.0 0.0 0.0 1.0 \n",
"816 0.0 0.0 0.0 0.0 \n",
"890 0.0 0.0 0.0 0.0 \n",
"738 0.0 0.0 0.0 0.0 \n",
"61 0.0 0.0 0.0 1.0 \n",
"\n",
" Is_married Pclass Age SibSp Sex_male \n",
"PassengerId \n",
"145 0 -0.379423 -0.869506 -0.473465 1.0 \n",
"206 0 0.821241 -2.102186 -0.473465 0.0 \n",
"349 0 0.821241 -2.025143 0.437635 1.0 \n",
"329 1 0.821241 0.132047 0.437635 0.0 \n",
"289 0 -0.379423 0.979514 -0.473465 1.0 \n",
"... ... ... ... ... ... \n",
"756 0 -0.379423 -2.204652 0.437635 1.0 \n",
"816 0 -1.580088 -0.099081 -0.473465 1.0 \n",
"890 0 -1.580088 -0.253166 -0.473465 1.0 \n",
"738 0 -1.580088 0.440217 -0.473465 1.0 \n",
"61 0 0.821241 -0.561336 -0.473465 1.0 \n",
"\n",
"[712 rows x 13 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для классификации\n",
"\n",
"logistic -- логистическая регрессия\n",
"\n",
"ridge -- логистическая регрессия с L2-регуляризацией (как в гребневой регрессии) и балансировкой классов\n",
"\n",
"decision_tree -- дерево решений\n",
"\n",
"knn -- k-ближайших соседей\n",
"\n",
"naive_bayes -- наивный байесовский классификатор\n",
"\n",
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"\n",
"random_forest -- метод случайного леса (набор деревьев решений)\n",
"\n",
"mlp -- многослойный персептрон (нейронная сеть)\n",
"\n",
"Документация: https://scikit-learn.org/1.5/supervised_learning.html"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"# Candidate classifiers keyed by a short name. Each entry is later replaced by the\n",
"# results dict returned from run_classification (see the training loop below).\n",
"# Estimators with internal randomness share random_state so that\n",
"# \"Restart & Run All\" reproduces the reported metrics.\n",
"class_models = {\n",
"    \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
"    # \"ridge\" is an L2-penalised logistic regression with balanced class weights,\n",
"    # not sklearn's RidgeClassifier: RidgeClassifier(CV) has no predict_proba.\n",
"    # NOTE(review): confirm run_classification relies on predict_proba for ROC AUC.\n",
"    \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
"    \"decision_tree\": {\n",
"        \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
"    },\n",
"    \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
"    \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
"    \"gradient_boosting\": {\n",
"        # random_state is required for determinism: sklearn randomly permutes the\n",
"        # features at every split, even without subsampling.\n",
"        \"model\": ensemble.GradientBoostingClassifier(\n",
"            n_estimators=210, random_state=random_state\n",
"        )\n",
"    },\n",
"    \"random_forest\": {\n",
"        \"model\": ensemble.RandomForestClassifier(\n",
"            max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
"        )\n",
"    },\n",
"    \"mlp\": {\n",
"        \"model\": neural_network.MLPClassifier(\n",
"            hidden_layer_sizes=(7,),\n",
"            max_iter=500,\n",
"            early_stopping=True,\n",
"            random_state=random_state,\n",
"        )\n",
"    },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"from sklearn.base import clone\n",
"\n",
"from src.utils import run_classification\n",
"\n",
"# Fit one full pipeline (preprocessing + model) per candidate and replace its\n",
"# config entry with the evaluation results returned by run_classification.\n",
"for model_name, model_entry in list(class_models.items()):\n",
"    print(f\"Model: {model_name}\")\n",
"    model = model_entry[\"model\"]\n",
"\n",
"    # clone() gives each pipeline its own unfitted copy of the preprocessor.\n",
"    # sklearn's Pipeline fits its steps in place, so reusing pipeline_end directly\n",
"    # would make every stored pipeline share (and refit) the same object.\n",
"    pipeline = Pipeline([(\"pipeline\", clone(pipeline_end)), (\"model\", model)]).fit(\n",
"        X_train, y_train.values.ravel()\n",
"    )\n",
"\n",
"    class_models[model_name] = run_classification(\n",
"        pipeline, X_train, X_test, y_train, y_test\n",
"    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Сводная таблица оценок качества для использованных моделей классификации\n",
"\n",
"Документация: https://scikit-learn.org/1.5/modules/model_evaluation.html"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Матрица неточностей"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import math\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"\n",
"# One confusion matrix per model on a two-column grid. ceil() keeps the last\n",
"# model when the count is odd; squeeze=False guarantees a 2-D Axes array.\n",
"_, ax = plt.subplots(\n",
"    math.ceil(len(class_models) / 2), 2,\n",
"    figsize=(12, 10), sharex=False, sharey=False, squeeze=False,\n",
")\n",
"for index, key in enumerate(class_models.keys()):\n",
"    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
"    # display_labels follow the Survived target encoding: 0 -> Died, 1 -> Survived.\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Died\", \"Survived\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"    disp.ax_.set_title(key)\n",
"\n",
"# Hide the spare panel left over when the number of models is odd.\n",
"for spare_ax in ax.flat[len(class_models):]:\n",
"    spare_ax.set_visible(False)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_4d92f_row0_col0, #T_4d92f_row7_col1 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row0_col1, #T_4d92f_row3_col2, #T_4d92f_row5_col3 {\n",
" background-color: #77d153;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row0_col2 {\n",
" background-color: #a5db36;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row0_col3 {\n",
" background-color: #7cd250;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row0_col4, #T_4d92f_row0_col5, #T_4d92f_row0_col6, #T_4d92f_row0_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row1_col0 {\n",
" background-color: #a0da39;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row1_col1 {\n",
" background-color: #7fd34e;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row1_col2 {\n",
" background-color: #93d741;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row1_col3 {\n",
" background-color: #70cf57;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row1_col4 {\n",
" background-color: #d45270;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row1_col5, #T_4d92f_row1_col6 {\n",
" background-color: #d6556d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row1_col7 {\n",
" background-color: #d7566c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row2_col0 {\n",
" background-color: #3bbb75;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row2_col1 {\n",
" background-color: #84d44b;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row2_col2, #T_4d92f_row4_col0 {\n",
" background-color: #6ece58;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row2_col3 {\n",
" background-color: #65cb5e;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row2_col4, #T_4d92f_row5_col4 {\n",
" background-color: #a62098;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row2_col5, #T_4d92f_row3_col5, #T_4d92f_row4_col5, #T_4d92f_row4_col7 {\n",
" background-color: #d35171;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row2_col6 {\n",
" background-color: #c03a83;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row2_col7 {\n",
" background-color: #d5536f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row3_col0 {\n",
" background-color: #81d34d;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row3_col1, #T_4d92f_row6_col2, #T_4d92f_row6_col3, #T_4d92f_row7_col0 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row3_col3 {\n",
" background-color: #56c667;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row3_col4 {\n",
" background-color: #c33d80;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row3_col6, #T_4d92f_row5_col7 {\n",
" background-color: #cc4977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row3_col7 {\n",
" background-color: #d24f71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row4_col1 {\n",
" background-color: #9bd93c;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row4_col2 {\n",
" background-color: #6ccd5a;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row4_col3 {\n",
" background-color: #5cc863;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row4_col4 {\n",
" background-color: #b83289;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row4_col6 {\n",
" background-color: #c7427c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row5_col0 {\n",
" background-color: #2db27d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row5_col1 {\n",
" background-color: #26ad81;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row5_col2 {\n",
" background-color: #89d548;\n",
" color: #000000;\n",
"}\n",
"#T_4d92f_row5_col5 {\n",
" background-color: #ae2892;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row5_col6 {\n",
" background-color: #c43e7f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row6_col0, #T_4d92f_row6_col1, #T_4d92f_row7_col2, #T_4d92f_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row6_col4, #T_4d92f_row7_col5, #T_4d92f_row7_col6, #T_4d92f_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row6_col5 {\n",
" background-color: #6700a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row6_col6 {\n",
" background-color: #b32c8e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row6_col7 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d92f_row7_col4 {\n",
" background-color: #5002a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_4d92f\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_4d92f_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_4d92f_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_4d92f_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_4d92f_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_4d92f_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_4d92f_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_4d92f_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_4d92f_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_4d92f_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_4d92f_row0_col0\" class=\"data row0 col0\" >0.894340</td>\n",
" <td id=\"T_4d92f_row0_col1\" class=\"data row0 col1\" >0.794118</td>\n",
" <td id=\"T_4d92f_row0_col2\" class=\"data row0 col2\" >0.868132</td>\n",
" <td id=\"T_4d92f_row0_col3\" class=\"data row0 col3\" >0.782609</td>\n",
" <td id=\"T_4d92f_row0_col4\" class=\"data row0 col4\" >0.910112</td>\n",
" <td id=\"T_4d92f_row0_col5\" class=\"data row0 col5\" >0.837989</td>\n",
" <td id=\"T_4d92f_row0_col6\" class=\"data row0 col6\" >0.881041</td>\n",
" <td id=\"T_4d92f_row0_col7\" class=\"data row0 col7\" >0.788321</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4d92f_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
" <td id=\"T_4d92f_row1_col0\" class=\"data row1 col0\" >0.889764</td>\n",
" <td id=\"T_4d92f_row1_col1\" class=\"data row1 col1\" >0.800000</td>\n",
" <td id=\"T_4d92f_row1_col2\" class=\"data row1 col2\" >0.827839</td>\n",
" <td id=\"T_4d92f_row1_col3\" class=\"data row1 col3\" >0.753623</td>\n",
" <td id=\"T_4d92f_row1_col4\" class=\"data row1 col4\" >0.894663</td>\n",
" <td id=\"T_4d92f_row1_col5\" class=\"data row1 col5\" >0.832402</td>\n",
" <td id=\"T_4d92f_row1_col6\" class=\"data row1 col6\" >0.857685</td>\n",
" <td id=\"T_4d92f_row1_col7\" class=\"data row1 col7\" >0.776119</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4d92f_level0_row2\" class=\"row_heading level0 row2\" >logistic</th>\n",
" <td id=\"T_4d92f_row2_col0\" class=\"data row2 col0\" >0.751880</td>\n",
" <td id=\"T_4d92f_row2_col1\" class=\"data row2 col1\" >0.806452</td>\n",
" <td id=\"T_4d92f_row2_col2\" class=\"data row2 col2\" >0.732601</td>\n",
" <td id=\"T_4d92f_row2_col3\" class=\"data row2 col3\" >0.724638</td>\n",
" <td id=\"T_4d92f_row2_col4\" class=\"data row2 col4\" >0.804775</td>\n",
" <td id=\"T_4d92f_row2_col5\" class=\"data row2 col5\" >0.826816</td>\n",
" <td id=\"T_4d92f_row2_col6\" class=\"data row2 col6\" >0.742115</td>\n",
" <td id=\"T_4d92f_row2_col7\" class=\"data row2 col7\" >0.763359</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4d92f_level0_row3\" class=\"row_heading level0 row3\" >decision_tree</th>\n",
" <td id=\"T_4d92f_row3_col0\" class=\"data row3 col0\" >0.852459</td>\n",
" <td id=\"T_4d92f_row3_col1\" class=\"data row3 col1\" >0.839286</td>\n",
" <td id=\"T_4d92f_row3_col2\" class=\"data row3 col2\" >0.761905</td>\n",
" <td id=\"T_4d92f_row3_col3\" class=\"data row3 col3\" >0.681159</td>\n",
" <td id=\"T_4d92f_row3_col4\" class=\"data row3 col4\" >0.858146</td>\n",
" <td id=\"T_4d92f_row3_col5\" class=\"data row3 col5\" >0.826816</td>\n",
" <td id=\"T_4d92f_row3_col6\" class=\"data row3 col6\" >0.804642</td>\n",
" <td id=\"T_4d92f_row3_col7\" class=\"data row3 col7\" >0.752000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4d92f_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_4d92f_row4_col0\" class=\"data row4 col0\" >0.829167</td>\n",
" <td id=\"T_4d92f_row4_col1\" class=\"data row4 col1\" >0.827586</td>\n",
" <td id=\"T_4d92f_row4_col2\" class=\"data row4 col2\" >0.728938</td>\n",
" <td id=\"T_4d92f_row4_col3\" class=\"data row4 col3\" >0.695652</td>\n",
" <td id=\"T_4d92f_row4_col4\" class=\"data row4 col4\" >0.838483</td>\n",
" <td id=\"T_4d92f_row4_col5\" class=\"data row4 col5\" >0.826816</td>\n",
" <td id=\"T_4d92f_row4_col6\" class=\"data row4 col6\" >0.775828</td>\n",
" <td id=\"T_4d92f_row4_col7\" class=\"data row4 col7\" >0.755906</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4d92f_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
" <td id=\"T_4d92f_row5_col0\" class=\"data row5 col0\" >0.720395</td>\n",
" <td id=\"T_4d92f_row5_col1\" class=\"data row5 col1\" >0.688312</td>\n",
" <td id=\"T_4d92f_row5_col2\" class=\"data row5 col2\" >0.802198</td>\n",
" <td id=\"T_4d92f_row5_col3\" class=\"data row5 col3\" >0.768116</td>\n",
" <td id=\"T_4d92f_row5_col4\" class=\"data row5 col4\" >0.804775</td>\n",
" <td id=\"T_4d92f_row5_col5\" class=\"data row5 col5\" >0.776536</td>\n",
" <td id=\"T_4d92f_row5_col6\" class=\"data row5 col6\" >0.759099</td>\n",
" <td id=\"T_4d92f_row5_col7\" class=\"data row5 col7\" >0.726027</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4d92f_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_4d92f_row6_col0\" class=\"data row6 col0\" >0.554524</td>\n",
" <td id=\"T_4d92f_row6_col1\" class=\"data row6 col1\" >0.575472</td>\n",
" <td id=\"T_4d92f_row6_col2\" class=\"data row6 col2\" >0.875458</td>\n",
" <td id=\"T_4d92f_row6_col3\" class=\"data row6 col3\" >0.884058</td>\n",
" <td id=\"T_4d92f_row6_col4\" class=\"data row6 col4\" >0.682584</td>\n",
" <td id=\"T_4d92f_row6_col5\" class=\"data row6 col5\" >0.703911</td>\n",
" <td id=\"T_4d92f_row6_col6\" class=\"data row6 col6\" >0.678977</td>\n",
" <td id=\"T_4d92f_row6_col7\" class=\"data row6 col7\" >0.697143</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_4d92f_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_4d92f_row7_col0\" class=\"data row7 col0\" >0.900000</td>\n",
" <td id=\"T_4d92f_row7_col1\" class=\"data row7 col1\" >0.833333</td>\n",
" <td id=\"T_4d92f_row7_col2\" class=\"data row7 col2\" >0.197802</td>\n",
" <td id=\"T_4d92f_row7_col3\" class=\"data row7 col3\" >0.217391</td>\n",
" <td id=\"T_4d92f_row7_col4\" class=\"data row7 col4\" >0.683989</td>\n",
" <td id=\"T_4d92f_row7_col5\" class=\"data row7 col5\" >0.681564</td>\n",
" <td id=\"T_4d92f_row7_col6\" class=\"data row7 col6\" >0.324324</td>\n",
" <td id=\"T_4d92f_row7_col7\" class=\"data row7 col7\" >0.344828</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2827fc5e2d0>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Precision, recall, accuracy and F1 on train and test for every model,\n",
"# ordered by test accuracy (best first).\n",
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
"    [\n",
"        f\"{metric}_{split}\"\n",
"        for metric in (\"Precision\", \"Recall\", \"Accuracy\", \"F1\")\n",
"        for split in (\"train\", \"test\")\n",
"    ]\n",
"]\n",
"(\n",
"    class_metrics.sort_values(by=\"Accuracy_test\", ascending=False)\n",
"    .style.background_gradient(\n",
"        cmap=\"plasma\",\n",
"        low=0.3,\n",
"        high=1,\n",
"        subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"viridis\",\n",
"        low=1,\n",
"        high=0.3,\n",
"        subset=[\"Precision_train\", \"Precision_test\", \"Recall_train\", \"Recall_test\"],\n",
"    )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Площадь под ROC-кривой (ROC AUC), каппа Коэна, коэффициент корреляции Мэтьюса (MCC)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_71e02_row0_col0, #T_71e02_row0_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_71e02_row0_col2, #T_71e02_row0_col3, #T_71e02_row0_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row1_col0, #T_71e02_row4_col0, #T_71e02_row4_col1, #T_71e02_row5_col0 {\n",
" background-color: #93d741;\n",
" color: #000000;\n",
"}\n",
"#T_71e02_row1_col1 {\n",
" background-color: #98d83e;\n",
" color: #000000;\n",
"}\n",
"#T_71e02_row1_col2 {\n",
" background-color: #d7566c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row1_col3 {\n",
" background-color: #d45270;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row1_col4, #T_71e02_row4_col4, #T_71e02_row5_col3, #T_71e02_row5_col4 {\n",
" background-color: #d24f71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row2_col0 {\n",
" background-color: #42be71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row2_col1 {\n",
" background-color: #7fd34e;\n",
" color: #000000;\n",
"}\n",
"#T_71e02_row2_col2, #T_71e02_row3_col2 {\n",
" background-color: #d5536f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row2_col3 {\n",
" background-color: #be3885;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row2_col4 {\n",
" background-color: #b6308b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row3_col0 {\n",
" background-color: #9dd93b;\n",
" color: #000000;\n",
"}\n",
"#T_71e02_row3_col1 {\n",
" background-color: #a0da39;\n",
" color: #000000;\n",
"}\n",
"#T_71e02_row3_col3, #T_71e02_row3_col4 {\n",
" background-color: #d6556d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row4_col2 {\n",
" background-color: #cc4977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row4_col3 {\n",
" background-color: #d35171;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row5_col1 {\n",
" background-color: #90d743;\n",
" color: #000000;\n",
"}\n",
"#T_71e02_row5_col2 {\n",
" background-color: #a82296;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row6_col0 {\n",
" background-color: #21908d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row6_col1 {\n",
" background-color: #6ece58;\n",
" color: #000000;\n",
"}\n",
"#T_71e02_row6_col2 {\n",
" background-color: #a11b9b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row6_col3 {\n",
" background-color: #9e199d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row6_col4 {\n",
" background-color: #9c179e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row7_col0, #T_71e02_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_71e02_row7_col2, #T_71e02_row7_col3, #T_71e02_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_71e02\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_71e02_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_71e02_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_71e02_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_71e02_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_71e02_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_71e02_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_71e02_row0_col0\" class=\"data row0 col0\" >0.837989</td>\n",
" <td id=\"T_71e02_row0_col1\" class=\"data row0 col1\" >0.788321</td>\n",
" <td id=\"T_71e02_row0_col2\" class=\"data row0 col2\" >0.858893</td>\n",
" <td id=\"T_71e02_row0_col3\" class=\"data row0 col3\" >0.657111</td>\n",
" <td id=\"T_71e02_row0_col4\" class=\"data row0 col4\" >0.657157</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_71e02_level0_row1\" class=\"row_heading level0 row1\" >logistic</th>\n",
" <td id=\"T_71e02_row1_col0\" class=\"data row1 col0\" >0.826816</td>\n",
" <td id=\"T_71e02_row1_col1\" class=\"data row1 col1\" >0.763359</td>\n",
" <td id=\"T_71e02_row1_col2\" class=\"data row1 col2\" >0.854084</td>\n",
" <td id=\"T_71e02_row1_col3\" class=\"data row1 col3\" >0.627409</td>\n",
" <td id=\"T_71e02_row1_col4\" class=\"data row1 col4\" >0.629641</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_71e02_level0_row2\" class=\"row_heading level0 row2\" >ridge</th>\n",
" <td id=\"T_71e02_row2_col0\" class=\"data row2 col0\" >0.776536</td>\n",
" <td id=\"T_71e02_row2_col1\" class=\"data row2 col1\" >0.726027</td>\n",
" <td id=\"T_71e02_row2_col2\" class=\"data row2 col2\" >0.851054</td>\n",
" <td id=\"T_71e02_row2_col3\" class=\"data row2 col3\" >0.538303</td>\n",
" <td id=\"T_71e02_row2_col4\" class=\"data row2 col4\" >0.540613</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_71e02_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
" <td id=\"T_71e02_row3_col0\" class=\"data row3 col0\" >0.832402</td>\n",
" <td id=\"T_71e02_row3_col1\" class=\"data row3 col1\" >0.776119</td>\n",
" <td id=\"T_71e02_row3_col2\" class=\"data row3 col2\" >0.850922</td>\n",
" <td id=\"T_71e02_row3_col3\" class=\"data row3 col3\" >0.642381</td>\n",
" <td id=\"T_71e02_row3_col4\" class=\"data row3 col4\" >0.643113</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_71e02_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_71e02_row4_col0\" class=\"data row4 col0\" >0.826816</td>\n",
" <td id=\"T_71e02_row4_col1\" class=\"data row4 col1\" >0.755906</td>\n",
" <td id=\"T_71e02_row4_col2\" class=\"data row4 col2\" >0.838735</td>\n",
" <td id=\"T_71e02_row4_col3\" class=\"data row4 col3\" >0.623260</td>\n",
" <td id=\"T_71e02_row4_col4\" class=\"data row4 col4\" >0.628905</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_71e02_level0_row5\" class=\"row_heading level0 row5\" >decision_tree</th>\n",
" <td id=\"T_71e02_row5_col0\" class=\"data row5 col0\" >0.826816</td>\n",
" <td id=\"T_71e02_row5_col1\" class=\"data row5 col1\" >0.752000</td>\n",
" <td id=\"T_71e02_row5_col2\" class=\"data row5 col2\" >0.794137</td>\n",
" <td id=\"T_71e02_row5_col3\" class=\"data row5 col3\" >0.621151</td>\n",
" <td id=\"T_71e02_row5_col4\" class=\"data row5 col4\" >0.629142</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_71e02_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_71e02_row6_col0\" class=\"data row6 col0\" >0.703911</td>\n",
" <td id=\"T_71e02_row6_col1\" class=\"data row6 col1\" >0.697143</td>\n",
" <td id=\"T_71e02_row6_col2\" class=\"data row6 col2\" >0.785903</td>\n",
" <td id=\"T_71e02_row6_col3\" class=\"data row6 col3\" >0.431814</td>\n",
" <td id=\"T_71e02_row6_col4\" class=\"data row6 col4\" >0.470403</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_71e02_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_71e02_row7_col0\" class=\"data row7 col0\" >0.681564</td>\n",
" <td id=\"T_71e02_row7_col1\" class=\"data row7 col1\" >0.344828</td>\n",
" <td id=\"T_71e02_row7_col2\" class=\"data row7 col2\" >0.712714</td>\n",
" <td id=\"T_71e02_row7_col3\" class=\"data row7 col3\" >0.220490</td>\n",
" <td id=\"T_71e02_row7_col4\" class=\"data row7 col4\" >0.307678</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2827fc885f0>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test-set summary with ROC AUC, Cohen's kappa and MCC, ordered by ROC AUC.\n",
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
"    [\"Accuracy_test\", \"F1_test\", \"ROC_AUC_test\", \"Cohen_kappa_test\", \"MCC_test\"]\n",
"]\n",
"(\n",
"    class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)\n",
"    .style.background_gradient(\n",
"        cmap=\"plasma\",\n",
"        low=0.3,\n",
"        high=1,\n",
"        subset=[\"ROC_AUC_test\", \"MCC_test\", \"Cohen_kappa_test\"],\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"viridis\", low=1, high=0.3, subset=[\"Accuracy_test\", \"F1_test\"]\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'random_forest'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Best model = highest Matthews correlation coefficient on the test set.\n",
"# class_metrics is the test-metric table built in the previous cell;\n",
"# idxmax returns the first maximum, so ties resolve deterministically.\n",
"best_model = str(class_metrics[\"MCC_test\"].idxmax())\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 29'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" <th>Predicted</th>\n",
" <th>Pclass</th>\n",
" <th>Name</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Ticket</th>\n",
" <th>Fare</th>\n",
" <th>Cabin</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" <tr>\n",
" <th>PassengerId</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>26</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Asplund, Mrs. Carl Oscar (Selma Augusta Emilia...</td>\n",
" <td>female</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" <td>347077</td>\n",
" <td>31.3875</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>72</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Goodwin, Miss. Lillian Amy</td>\n",
" <td>female</td>\n",
" <td>16.0</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>CA 2144</td>\n",
" <td>46.9000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>White, Mr. Richard Frasar</td>\n",
" <td>male</td>\n",
" <td>21.0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>35281</td>\n",
" <td>77.2875</td>\n",
" <td>D26</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>108</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Moss, Mr. Albert Johan</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>312991</td>\n",
" <td>7.7750</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>128</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Madsen, Mr. Fridtjof Arne</td>\n",
" <td>male</td>\n",
" <td>24.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>C 17369</td>\n",
" <td>7.1417</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>193</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Andersen-Jensen, Miss. Carla Christine Nielsine</td>\n",
" <td>female</td>\n",
" <td>19.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>350046</td>\n",
" <td>7.8542</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>241</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Zabour, Miss. Thamine</td>\n",
" <td>female</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2665</td>\n",
" <td>14.4542</td>\n",
" <td>NaN</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>272</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Tornquist, Mr. William Henry</td>\n",
" <td>male</td>\n",
" <td>25.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>LINE</td>\n",
" <td>0.0000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>293</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Levy, Mr. Rene Jacques</td>\n",
" <td>male</td>\n",
" <td>36.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>SC/Paris 2163</td>\n",
" <td>12.8750</td>\n",
" <td>D</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>352</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Williams-Lambert, Mr. Fletcher Fellows</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>113510</td>\n",
" <td>35.0000</td>\n",
" <td>C128</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>358</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Funk, Miss. Annie Clemmer</td>\n",
" <td>female</td>\n",
" <td>38.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>237671</td>\n",
" <td>13.0000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>378</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Widener, Mr. Harry Elkins</td>\n",
" <td>male</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>113503</td>\n",
" <td>211.5000</td>\n",
" <td>C82</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>445</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Johannesen-Bratthammer, Mr. Bernt</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>65306</td>\n",
" <td>8.1125</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Peuchen, Major. Arthur Godfrey</td>\n",
" <td>male</td>\n",
" <td>52.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>113786</td>\n",
" <td>30.5000</td>\n",
" <td>C104</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>508</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Bradley, Mr. George (\"George Arthur Brayton\")</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>111427</td>\n",
" <td>26.5500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>511</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Daly, Mr. Eugene Patrick</td>\n",
" <td>male</td>\n",
" <td>29.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>382651</td>\n",
" <td>7.7500</td>\n",
" <td>NaN</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" <tr>\n",
" <th>570</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Jonsson, Mr. Carl</td>\n",
" <td>male</td>\n",
" <td>32.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>350417</td>\n",
" <td>7.8542</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>579</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Caram, Mrs. Joseph (Maria Elias)</td>\n",
" <td>female</td>\n",
" <td>NaN</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2689</td>\n",
" <td>14.4583</td>\n",
" <td>NaN</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>584</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Ross, Mr. John Hugo</td>\n",
" <td>male</td>\n",
" <td>36.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>13049</td>\n",
" <td>40.1250</td>\n",
" <td>A10</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>588</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Frolicher-Stehli, Mr. Maxmillian</td>\n",
" <td>male</td>\n",
" <td>60.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>13567</td>\n",
" <td>79.2000</td>\n",
" <td>B41</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>618</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Lobb, Mrs. William Arthur (Cordelia K Stanlick)</td>\n",
" <td>female</td>\n",
" <td>26.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>A/5. 3336</td>\n",
" <td>16.1000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>658</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Bourke, Mrs. John (Catherine)</td>\n",
" <td>female</td>\n",
" <td>32.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>364849</td>\n",
" <td>15.5000</td>\n",
" <td>NaN</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" <tr>\n",
" <th>661</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>Frauenthal, Dr. Henry William</td>\n",
" <td>male</td>\n",
" <td>50.0</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>PC 17611</td>\n",
" <td>133.6500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>674</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>Wilhelms, Mr. Charles</td>\n",
" <td>male</td>\n",
" <td>31.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>244270</td>\n",
" <td>13.0000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>745</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Stranden, Mr. Juho</td>\n",
" <td>male</td>\n",
" <td>31.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>STON/O 2. 3101288</td>\n",
" <td>7.9250</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>773</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Mack, Mrs. (Mary)</td>\n",
" <td>female</td>\n",
" <td>57.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>S.O./P.P. 3</td>\n",
" <td>10.5000</td>\n",
" <td>E77</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>807</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Andrews, Mr. Thomas Jr</td>\n",
" <td>male</td>\n",
" <td>39.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>112050</td>\n",
" <td>0.0000</td>\n",
" <td>A36</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>814</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Andersson, Miss. Ebba Iris Alfrida</td>\n",
" <td>female</td>\n",
" <td>6.0</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>347082</td>\n",
" <td>31.2750</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>829</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>McCormack, Mr. Thomas Joseph</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>367228</td>\n",
" <td>7.7500</td>\n",
" <td>NaN</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Survived Predicted Pclass \\\n",
"PassengerId \n",
"26 1 0 3 \n",
"72 0 1 3 \n",
"103 0 1 1 \n",
"108 1 0 3 \n",
"128 1 0 3 \n",
"193 1 0 3 \n",
"241 0 1 3 \n",
"272 1 0 3 \n",
"293 0 1 2 \n",
"352 0 1 1 \n",
"358 0 1 2 \n",
"378 0 1 1 \n",
"445 1 0 3 \n",
"450 1 0 1 \n",
"508 1 0 1 \n",
"511 1 0 3 \n",
"570 1 0 3 \n",
"579 0 1 3 \n",
"584 0 1 1 \n",
"588 1 0 1 \n",
"618 0 1 3 \n",
"658 0 1 3 \n",
"661 1 0 1 \n",
"674 1 0 2 \n",
"745 1 0 3 \n",
"773 0 1 2 \n",
"807 0 1 1 \n",
"814 0 1 3 \n",
"829 1 0 3 \n",
"\n",
" Name Sex Age \\\n",
"PassengerId \n",
"26 Asplund, Mrs. Carl Oscar (Selma Augusta Emilia... female 38.0 \n",
"72 Goodwin, Miss. Lillian Amy female 16.0 \n",
"103 White, Mr. Richard Frasar male 21.0 \n",
"108 Moss, Mr. Albert Johan male NaN \n",
"128 Madsen, Mr. Fridtjof Arne male 24.0 \n",
"193 Andersen-Jensen, Miss. Carla Christine Nielsine female 19.0 \n",
"241 Zabour, Miss. Thamine female NaN \n",
"272 Tornquist, Mr. William Henry male 25.0 \n",
"293 Levy, Mr. Rene Jacques male 36.0 \n",
"352 Williams-Lambert, Mr. Fletcher Fellows male NaN \n",
"358 Funk, Miss. Annie Clemmer female 38.0 \n",
"378 Widener, Mr. Harry Elkins male 27.0 \n",
"445 Johannesen-Bratthammer, Mr. Bernt male NaN \n",
"450 Peuchen, Major. Arthur Godfrey male 52.0 \n",
"508 Bradley, Mr. George (\"George Arthur Brayton\") male NaN \n",
"511 Daly, Mr. Eugene Patrick male 29.0 \n",
"570 Jonsson, Mr. Carl male 32.0 \n",
"579 Caram, Mrs. Joseph (Maria Elias) female NaN \n",
"584 Ross, Mr. John Hugo male 36.0 \n",
"588 Frolicher-Stehli, Mr. Maxmillian male 60.0 \n",
"618 Lobb, Mrs. William Arthur (Cordelia K Stanlick) female 26.0 \n",
"658 Bourke, Mrs. John (Catherine) female 32.0 \n",
"661 Frauenthal, Dr. Henry William male 50.0 \n",
"674 Wilhelms, Mr. Charles male 31.0 \n",
"745 Stranden, Mr. Juho male 31.0 \n",
"773 Mack, Mrs. (Mary) female 57.0 \n",
"807 Andrews, Mr. Thomas Jr male 39.0 \n",
"814 Andersson, Miss. Ebba Iris Alfrida female 6.0 \n",
"829 McCormack, Mr. Thomas Joseph male NaN \n",
"\n",
" SibSp Parch Ticket Fare Cabin Embarked \n",
"PassengerId \n",
"26 1 5 347077 31.3875 NaN S \n",
"72 5 2 CA 2144 46.9000 NaN S \n",
"103 0 1 35281 77.2875 D26 S \n",
"108 0 0 312991 7.7750 NaN S \n",
"128 0 0 C 17369 7.1417 NaN S \n",
"193 1 0 350046 7.8542 NaN S \n",
"241 1 0 2665 14.4542 NaN C \n",
"272 0 0 LINE 0.0000 NaN S \n",
"293 0 0 SC/Paris 2163 12.8750 D C \n",
"352 0 0 113510 35.0000 C128 S \n",
"358 0 0 237671 13.0000 NaN S \n",
"378 0 2 113503 211.5000 C82 C \n",
"445 0 0 65306 8.1125 NaN S \n",
"450 0 0 113786 30.5000 C104 S \n",
"508 0 0 111427 26.5500 NaN S \n",
"511 0 0 382651 7.7500 NaN Q \n",
"570 0 0 350417 7.8542 NaN S \n",
"579 1 0 2689 14.4583 NaN C \n",
"584 0 0 13049 40.1250 A10 C \n",
"588 1 1 13567 79.2000 B41 C \n",
"618 1 0 A/5. 3336 16.1000 NaN S \n",
"658 1 1 364849 15.5000 NaN Q \n",
"661 2 0 PC 17611 133.6500 NaN S \n",
"674 0 0 244270 13.0000 NaN S \n",
"745 0 0 STON/O 2. 3101288 7.9250 NaN S \n",
"773 0 0 S.O./P.P. 3 10.5000 E77 S \n",
"807 0 0 112050 0.0000 A36 S \n",
"814 4 2 347082 31.2750 NaN S \n",
"829 0 0 367228 7.7500 NaN Q "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"Survived\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Name</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Ticket</th>\n",
" <th>Fare</th>\n",
" <th>Cabin</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Peuchen, Major. Arthur Godfrey</td>\n",
" <td>male</td>\n",
" <td>52.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>113786</td>\n",
" <td>30.5</td>\n",
" <td>C104</td>\n",
" <td>S</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Survived Pclass Name Sex Age SibSp Parch \\\n",
"450 1 1 Peuchen, Major. Arthur Godfrey male 52.0 0 0 \n",
"\n",
" Ticket Fare Cabin Embarked \n",
"450 113786 30.5 C104 S "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Cabin_type_B</th>\n",
" <th>Cabin_type_C</th>\n",
" <th>Cabin_type_D</th>\n",
" <th>Cabin_type_E</th>\n",
" <th>Cabin_type_F</th>\n",
" <th>Cabin_type_G</th>\n",
" <th>Cabin_type_T</th>\n",
" <th>Cabin_type_u</th>\n",
" <th>Is_married</th>\n",
" <th>Pclass</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Sex_male</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>-1.580088</td>\n",
" <td>1.749939</td>\n",
" <td>-0.473465</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Cabin_type_B Cabin_type_C Cabin_type_D Cabin_type_E Cabin_type_F \\\n",
"450 0.0 1.0 0.0 0.0 0.0 \n",
"\n",
" Cabin_type_G Cabin_type_T Cabin_type_u Is_married Pclass Age \\\n",
"450 0.0 0.0 0.0 0.0 -1.580088 1.749939 \n",
"\n",
" SibSp Sex_male \n",
"450 -0.473465 1.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 0 (proba: [0.91145747 0.08854253])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 1'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 450\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Подбор гиперпараметров методом поиска по сетке\n",
"\n",
"https://www.kaggle.com/code/sociopath00/random-forest-using-gridsearchcv\n",
"\n",
"https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\user\\Projects\\python\\ckmai\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 7,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 30}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10],\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"pipeline = gs_optomizer.best_estimator_.fit(X_train, y_train.values.ravel())\n",
"\n",
"result = run_classification(pipeline, X_train, X_test, y_train, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сравнение метрик качества старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_52c7b_row0_col0, #T_52c7b_row0_col1, #T_52c7b_row0_col2, #T_52c7b_row1_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_52c7b_row0_col3, #T_52c7b_row1_col0, #T_52c7b_row1_col1, #T_52c7b_row1_col2 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_52c7b_row0_col4, #T_52c7b_row0_col6, #T_52c7b_row1_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_52c7b_row0_col5, #T_52c7b_row1_col5 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_52c7b_row0_col7, #T_52c7b_row1_col4, #T_52c7b_row1_col6 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_52c7b\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_52c7b_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_52c7b_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_52c7b_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_52c7b_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_52c7b_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_52c7b_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_52c7b_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_52c7b_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_52c7b_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_52c7b_row0_col0\" class=\"data row0 col0\" >0.894340</td>\n",
" <td id=\"T_52c7b_row0_col1\" class=\"data row0 col1\" >0.794118</td>\n",
" <td id=\"T_52c7b_row0_col2\" class=\"data row0 col2\" >0.868132</td>\n",
" <td id=\"T_52c7b_row0_col3\" class=\"data row0 col3\" >0.782609</td>\n",
" <td id=\"T_52c7b_row0_col4\" class=\"data row0 col4\" >0.910112</td>\n",
" <td id=\"T_52c7b_row0_col5\" class=\"data row0 col5\" >0.837989</td>\n",
" <td id=\"T_52c7b_row0_col6\" class=\"data row0 col6\" >0.881041</td>\n",
" <td id=\"T_52c7b_row0_col7\" class=\"data row0 col7\" >0.788321</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_52c7b_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_52c7b_row1_col0\" class=\"data row1 col0\" >0.800699</td>\n",
" <td id=\"T_52c7b_row1_col1\" class=\"data row1 col1\" >0.777778</td>\n",
" <td id=\"T_52c7b_row1_col2\" class=\"data row1 col2\" >0.838828</td>\n",
" <td id=\"T_52c7b_row1_col3\" class=\"data row1 col3\" >0.811594</td>\n",
" <td id=\"T_52c7b_row1_col4\" class=\"data row1 col4\" >0.858146</td>\n",
" <td id=\"T_52c7b_row1_col5\" class=\"data row1 col5\" >0.837989</td>\n",
" <td id=\"T_52c7b_row1_col6\" class=\"data row1 col6\" >0.819320</td>\n",
" <td id=\"T_52c7b_row1_col7\" class=\"data row1 col7\" >0.794326</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x28221285d30>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_3ee0b_row0_col0, #T_3ee0b_row1_col0 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_3ee0b_row0_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_3ee0b_row0_col2, #T_3ee0b_row0_col3, #T_3ee0b_row0_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_3ee0b_row1_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_3ee0b_row1_col2, #T_3ee0b_row1_col3, #T_3ee0b_row1_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_3ee0b\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_3ee0b_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_3ee0b_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_3ee0b_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_3ee0b_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_3ee0b_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_3ee0b_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_3ee0b_row0_col0\" class=\"data row0 col0\" >0.837989</td>\n",
" <td id=\"T_3ee0b_row0_col1\" class=\"data row0 col1\" >0.788321</td>\n",
" <td id=\"T_3ee0b_row0_col2\" class=\"data row0 col2\" >0.858893</td>\n",
" <td id=\"T_3ee0b_row0_col3\" class=\"data row0 col3\" >0.657111</td>\n",
" <td id=\"T_3ee0b_row0_col4\" class=\"data row0 col4\" >0.657157</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_3ee0b_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_3ee0b_row1_col0\" class=\"data row1 col0\" >0.837989</td>\n",
" <td id=\"T_3ee0b_row1_col1\" class=\"data row1 col1\" >0.794326</td>\n",
" <td id=\"T_3ee0b_row1_col2\" class=\"data row1 col2\" >0.866140</td>\n",
" <td id=\"T_3ee0b_row1_col3\" class=\"data row1 col3\" >0.660785</td>\n",
" <td id=\"T_3ee0b_row1_col4\" class=\"data row1 col4\" >0.661193</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x28220ffbc20>"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
"    confusion_matrix=c_matrix, display_labels=[\"Died\", \"Survived\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Регрессия"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Загрузка данных"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Density</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.06250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>25</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05979</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>35</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05404</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2 Density\n",
"0 20 0.0 0.0 1.06250\n",
"1 25 0.0 0.0 1.05979\n",
"2 35 0.0 0.0 1.05404"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Density</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>30</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.05696</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>55</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.04158</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>25</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>1.08438</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2 Density\n",
"0 30 0.00 0.0 1.05696\n",
"1 55 0.00 0.0 1.04158\n",
"2 25 0.05 0.0 1.08438"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"\n",
"density_train = pd.read_csv(\"data/density/density_train.csv\", sep=\";\", decimal=\",\")\n",
"density_test = pd.read_csv(\"data/density/density_test.csv\", sep=\";\", decimal=\",\")\n",
"\n",
"display(density_train.head(3))\n",
"display(density_test.head(3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование выборок"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>25</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>35</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2\n",
"0 20 0.0 0.0\n",
"1 25 0.0 0.0\n",
"2 35 0.0 0.0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Density</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.06250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.05979</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.05404</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Density\n",
"0 1.06250\n",
"1 1.05979\n",
"2 1.05404"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>30</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>55</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>25</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2\n",
"0 30 0.00 0.0\n",
"1 55 0.00 0.0\n",
"2 25 0.05 0.0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Density</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.05696</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.04158</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.08438</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Density\n",
"0 1.05696\n",
"1 1.04158\n",
"2 1.08438"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"density_y_train = pd.DataFrame(density_train[\"Density\"], columns=[\"Density\"])\n",
"density_train = density_train.drop([\"Density\"], axis=1)\n",
"\n",
"display(density_train.head(3))\n",
"display(density_y_train.head(3))\n",
"\n",
"density_y_test = pd.DataFrame(density_test[\"Density\"], columns=[\"Density\"])\n",
"density_test = density_test.drop([\"Density\"], axis=1)\n",
"\n",
"display(density_test.head(3))\n",
"display(density_y_test.head(3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"random_state = 9\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение и оценка моделей с помощью различных алгоритмов"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"from src.utils import run_regression\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" X_train = density_train\n",
" X_test = density_test\n",
" y_train = density_y_train\n",
" y_test = density_y_test\n",
"\n",
" model = models[model_name][\"model\"]\n",
" fitted_model = model.fit(\n",
" X_train.values, density_y_train.values.ravel()\n",
" )\n",
"\n",
" models[model_name] = run_regression(fitted_model, X_train, X_test, y_train, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод результатов оценки"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_63a66_row0_col0, #T_63a66_row0_col1, #T_63a66_row4_col0 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row0_col2, #T_63a66_row7_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row0_col3, #T_63a66_row1_col3, #T_63a66_row2_col3, #T_63a66_row3_col3, #T_63a66_row4_col3, #T_63a66_row7_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row1_col0, #T_63a66_row1_col1, #T_63a66_row2_col0 {\n",
" background-color: #26828e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row1_col2 {\n",
" background-color: #5901a5;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row2_col1, #T_63a66_row3_col0 {\n",
" background-color: #25838e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row2_col2 {\n",
" background-color: #6400a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row3_col1 {\n",
" background-color: #24868e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row3_col2 {\n",
" background-color: #7100a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row4_col1 {\n",
" background-color: #24878e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row4_col2 {\n",
" background-color: #7701a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row5_col0 {\n",
" background-color: #21908d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row5_col1 {\n",
" background-color: #21918c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row5_col2 {\n",
" background-color: #910ea3;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row5_col3 {\n",
" background-color: #d8576b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row6_col0 {\n",
" background-color: #38b977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row6_col1 {\n",
" background-color: #3bbb75;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row6_col2 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row6_col3 {\n",
" background-color: #b7318a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_63a66_row7_col0, #T_63a66_row7_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_63a66\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_63a66_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_63a66_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_63a66_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_63a66_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_63a66_level0_row0\" class=\"row_heading level0 row0\" >linear_poly</th>\n",
" <td id=\"T_63a66_row0_col0\" class=\"data row0 col0\" >0.000319</td>\n",
" <td id=\"T_63a66_row0_col1\" class=\"data row0 col1\" >0.000362</td>\n",
" <td id=\"T_63a66_row0_col2\" class=\"data row0 col2\" >0.016643</td>\n",
" <td id=\"T_63a66_row0_col3\" class=\"data row0 col3\" >0.999965</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_63a66_level0_row1\" class=\"row_heading level0 row1\" >linear_interact</th>\n",
" <td id=\"T_63a66_row1_col0\" class=\"data row1 col0\" >0.001131</td>\n",
" <td id=\"T_63a66_row1_col1\" class=\"data row1 col1\" >0.001491</td>\n",
" <td id=\"T_63a66_row1_col2\" class=\"data row1 col2\" >0.033198</td>\n",
" <td id=\"T_63a66_row1_col3\" class=\"data row1 col3\" >0.999413</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_63a66_level0_row2\" class=\"row_heading level0 row2\" >linear</th>\n",
" <td id=\"T_63a66_row2_col0\" class=\"data row2 col0\" >0.002464</td>\n",
" <td id=\"T_63a66_row2_col1\" class=\"data row2 col1\" >0.003261</td>\n",
" <td id=\"T_63a66_row2_col2\" class=\"data row2 col2\" >0.049891</td>\n",
" <td id=\"T_63a66_row2_col3\" class=\"data row2 col3\" >0.997191</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_63a66_level0_row3\" class=\"row_heading level0 row3\" >random_forest</th>\n",
" <td id=\"T_63a66_row3_col0\" class=\"data row3 col0\" >0.002716</td>\n",
" <td id=\"T_63a66_row3_col1\" class=\"data row3 col1\" >0.005575</td>\n",
" <td id=\"T_63a66_row3_col2\" class=\"data row3 col2\" >0.067298</td>\n",
" <td id=\"T_63a66_row3_col3\" class=\"data row3 col3\" >0.991788</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_63a66_level0_row4\" class=\"row_heading level0 row4\" >decision_tree</th>\n",
" <td id=\"T_63a66_row4_col0\" class=\"data row4 col0\" >0.000346</td>\n",
" <td id=\"T_63a66_row4_col1\" class=\"data row4 col1\" >0.006433</td>\n",
" <td id=\"T_63a66_row4_col2\" class=\"data row4 col2\" >0.076138</td>\n",
" <td id=\"T_63a66_row4_col3\" class=\"data row4 col3\" >0.989067</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_63a66_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
" <td id=\"T_63a66_row5_col0\" class=\"data row5 col0\" >0.013989</td>\n",
" <td id=\"T_63a66_row5_col1\" class=\"data row5 col1\" >0.015356</td>\n",
" <td id=\"T_63a66_row5_col2\" class=\"data row5 col2\" >0.116380</td>\n",
" <td id=\"T_63a66_row5_col3\" class=\"data row5 col3\" >0.937703</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_63a66_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_63a66_row6_col0\" class=\"data row6 col0\" >0.053108</td>\n",
" <td id=\"T_63a66_row6_col1\" class=\"data row6 col1\" >0.056776</td>\n",
" <td id=\"T_63a66_row6_col2\" class=\"data row6 col2\" >0.217611</td>\n",
" <td id=\"T_63a66_row6_col3\" class=\"data row6 col3\" >0.148414</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_63a66_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_63a66_row7_col0\" class=\"data row7 col0\" >0.095734</td>\n",
" <td id=\"T_63a66_row7_col1\" class=\"data row7 col1\" >0.099654</td>\n",
" <td id=\"T_63a66_row7_col2\" class=\"data row7 col2\" >0.270371</td>\n",
" <td id=\"T_63a66_row7_col3\" class=\"data row7 col3\" >-1.623554</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x282215a87a0>"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One row per model, keeping only the metrics shown in the comparison table.\n",
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
"    [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"# Error metrics (lower is better) share one colour scale; R2_test (higher is better)\n",
"# gets its own, so colours are not read in opposite directions across columns.\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
"    cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\", \"RMAE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"R2_test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выбор лучшей модели (с минимальным RMSE на тестовой выборке)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'linear_poly'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Best model = lowest test RMSE; idxmin returns its row label without sorting the table.\n",
"best_model = str(reg_metrics[\"RMSE_test\"].idxmin())\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод для обучающей выборки"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Density</th>\n",
" <th>DensityPred</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.06250</td>\n",
" <td>1.063174</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>25</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05979</td>\n",
" <td>1.060117</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>35</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05404</td>\n",
" <td>1.053941</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>40</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.05103</td>\n",
" <td>1.050822</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>45</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.04794</td>\n",
" <td>1.047683</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2 Density DensityPred\n",
"0 20 0.0 0.0 1.06250 1.063174\n",
"1 25 0.0 0.0 1.05979 1.060117\n",
"2 35 0.0 0.0 1.05404 1.053941\n",
"3 40 0.0 0.0 1.05103 1.050822\n",
"4 45 0.0 0.0 1.04794 1.047683"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Side-by-side preview: training inputs, true density and the best model's prediction.\n",
"pd.concat(\n",
"    objs=[\n",
"        density_train,\n",
"        density_y_train,\n",
"        pd.Series(\n",
"            data=models[best_model][\"train_preds\"],\n",
"            index=density_y_train.index,\n",
"            name=\"DensityPred\",\n",
"        ),\n",
"    ],\n",
"    axis=\"columns\",\n",
").iloc[:5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод для тестовой выборки"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>T</th>\n",
" <th>Al2O3</th>\n",
" <th>TiO2</th>\n",
" <th>Density</th>\n",
" <th>DensityPred</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>30</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.05696</td>\n",
" <td>1.057040</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>55</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>1.04158</td>\n",
" <td>1.041341</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>25</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>1.08438</td>\n",
" <td>1.084063</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>1.08112</td>\n",
" <td>1.080764</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>35</td>\n",
" <td>0.05</td>\n",
" <td>0.0</td>\n",
" <td>1.07781</td>\n",
" <td>1.077444</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" T Al2O3 TiO2 Density DensityPred\n",
"0 30 0.00 0.0 1.05696 1.057040\n",
"1 55 0.00 0.0 1.04158 1.041341\n",
"2 25 0.05 0.0 1.08438 1.084063\n",
"3 30 0.05 0.0 1.08112 1.080764\n",
"4 35 0.05 0.0 1.07781 1.077444"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Side-by-side preview: test inputs, true density and the best model's prediction.\n",
"pd.concat(\n",
"    objs=[\n",
"        density_test,\n",
"        density_y_test,\n",
"        pd.Series(\n",
"            data=models[best_model][\"preds\"],\n",
"            index=density_y_test.index,\n",
"            name=\"DensityPred\",\n",
"        ),\n",
"    ],\n",
"    axis=\"columns\",\n",
").iloc[:5]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}