MII_Salin_Oleg_PIbd-33/lec4.ipynb
2024-11-08 22:37:34 +04:00

2408 lines
234 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Загрузка набора данных"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"# Ask scikit-learn transformers to return pandas DataFrames.\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Seed reused by later cells.\n",
"random_state = 9\n",
"\n",
"df = pd.read_csv(\"data/Medical_insurance.csv\", index_col=False)\n",
"\n",
"# Encode the target: 1 for \"yes\", 0 for any other value.\n",
"is_smoker = df[\"smoker\"] == \"yes\"\n",
"df[\"smoker\"] = is_smoker.astype(\"int64\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовую выборки (80/20) для задачи классификации\n",
"\n",
"Целевой признак -- smoker"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>sex</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>smoker</th>\n",
" <th>region</th>\n",
" <th>charges</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>671</th>\n",
" <td>29</td>\n",
" <td>female</td>\n",
" <td>31.160</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>northeast</td>\n",
" <td>3943.59540</td>\n",
" </tr>\n",
" <tr>\n",
" <th>808</th>\n",
" <td>18</td>\n",
" <td>male</td>\n",
" <td>30.140</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>southeast</td>\n",
" <td>1131.50660</td>\n",
" </tr>\n",
" <tr>\n",
" <th>795</th>\n",
" <td>27</td>\n",
" <td>male</td>\n",
" <td>28.500</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>northwest</td>\n",
" <td>18310.74200</td>\n",
" </tr>\n",
" <tr>\n",
" <th>576</th>\n",
" <td>22</td>\n",
" <td>male</td>\n",
" <td>26.840</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>southeast</td>\n",
" <td>1664.99960</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1232</th>\n",
" <td>54</td>\n",
" <td>female</td>\n",
" <td>24.605</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>northwest</td>\n",
" <td>12479.70895</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>105</th>\n",
" <td>20</td>\n",
" <td>male</td>\n",
" <td>28.025</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>northwest</td>\n",
" <td>17560.37975</td>\n",
" </tr>\n",
" <tr>\n",
" <th>461</th>\n",
" <td>42</td>\n",
" <td>male</td>\n",
" <td>30.000</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>southwest</td>\n",
" <td>22144.03200</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2650</th>\n",
" <td>49</td>\n",
" <td>female</td>\n",
" <td>33.345</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>northeast</td>\n",
" <td>10370.91255</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1674</th>\n",
" <td>59</td>\n",
" <td>female</td>\n",
" <td>36.765</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>northeast</td>\n",
" <td>47896.79135</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2689</th>\n",
" <td>43</td>\n",
" <td>male</td>\n",
" <td>27.800</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>southwest</td>\n",
" <td>37829.72420</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2217 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" age sex bmi children smoker region charges\n",
"671 29 female 31.160 0 0 northeast 3943.59540\n",
"808 18 male 30.140 0 0 southeast 1131.50660\n",
"795 27 male 28.500 0 1 northwest 18310.74200\n",
"576 22 male 26.840 0 0 southeast 1664.99960\n",
"1232 54 female 24.605 3 0 northwest 12479.70895\n",
"... ... ... ... ... ... ... ...\n",
"105 20 male 28.025 1 1 northwest 17560.37975\n",
"461 42 male 30.000 0 1 southwest 22144.03200\n",
"2650 49 female 33.345 2 0 northeast 10370.91255\n",
"1674 59 female 36.765 1 1 northeast 47896.79135\n",
"2689 43 male 27.800 0 1 southwest 37829.72420\n",
"\n",
"[2217 rows x 7 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>smoker</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>671</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>808</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>795</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>576</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1232</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>105</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>461</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2650</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1674</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2689</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2217 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" smoker\n",
"671 0\n",
"808 0\n",
"795 1\n",
"576 0\n",
"1232 0\n",
"... ...\n",
"105 1\n",
"461 1\n",
"2650 0\n",
"1674 1\n",
"2689 1\n",
"\n",
"[2217 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>sex</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>smoker</th>\n",
" <th>region</th>\n",
" <th>charges</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>124</th>\n",
" <td>47</td>\n",
" <td>female</td>\n",
" <td>33.915</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>northwest</td>\n",
" <td>10115.00885</td>\n",
" </tr>\n",
" <tr>\n",
" <th>778</th>\n",
" <td>35</td>\n",
" <td>male</td>\n",
" <td>34.320</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>southeast</td>\n",
" <td>5934.37980</td>\n",
" </tr>\n",
" <tr>\n",
" <th>372</th>\n",
" <td>42</td>\n",
" <td>female</td>\n",
" <td>33.155</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>northeast</td>\n",
" <td>7639.41745</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1969</th>\n",
" <td>32</td>\n",
" <td>female</td>\n",
" <td>23.650</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>southeast</td>\n",
" <td>17626.23951</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2522</th>\n",
" <td>44</td>\n",
" <td>female</td>\n",
" <td>25.000</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>southwest</td>\n",
" <td>7623.51800</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>908</th>\n",
" <td>63</td>\n",
" <td>male</td>\n",
" <td>39.800</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>southwest</td>\n",
" <td>15170.06900</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1203</th>\n",
" <td>51</td>\n",
" <td>male</td>\n",
" <td>32.300</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>northeast</td>\n",
" <td>9964.06000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>45</th>\n",
" <td>55</td>\n",
" <td>male</td>\n",
" <td>37.300</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>southwest</td>\n",
" <td>20630.28351</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2669</th>\n",
" <td>18</td>\n",
" <td>male</td>\n",
" <td>30.030</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>southeast</td>\n",
" <td>1720.35370</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1230</th>\n",
" <td>52</td>\n",
" <td>male</td>\n",
" <td>34.485</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>northwest</td>\n",
" <td>60021.39897</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>555 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" age sex bmi children smoker region charges\n",
"124 47 female 33.915 3 0 northwest 10115.00885\n",
"778 35 male 34.320 3 0 southeast 5934.37980\n",
"372 42 female 33.155 1 0 northeast 7639.41745\n",
"1969 32 female 23.650 1 0 southeast 17626.23951\n",
"2522 44 female 25.000 1 0 southwest 7623.51800\n",
"... ... ... ... ... ... ... ...\n",
"908 63 male 39.800 3 0 southwest 15170.06900\n",
"1203 51 male 32.300 1 0 northeast 9964.06000\n",
"45 55 male 37.300 0 0 southwest 20630.28351\n",
"2669 18 male 30.030 1 0 southeast 1720.35370\n",
"1230 52 male 34.485 3 1 northwest 60021.39897\n",
"\n",
"[555 rows x 7 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>smoker</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>124</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>778</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>372</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1969</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2522</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>908</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1203</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>45</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2669</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1230</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>555 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" smoker\n",
"124 0\n",
"778 0\n",
"372 0\n",
"1969 0\n",
"2522 0\n",
"... ...\n",
"908 0\n",
"1203 0\n",
"45 0\n",
"2669 0\n",
"1230 1\n",
"\n",
"[555 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from utils import split_stratified_into_train_val_test\n",
"\n",
"# 80/20 split stratified on the target column; the validation fraction is zero.\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
"    df,\n",
"    stratify_colname=\"smoker\",\n",
"    target_colname=\"smoker\",\n",
"    frac_train=0.80,\n",
"    frac_val=0,\n",
"    frac_test=0.20,\n",
"    random_state=random_state,\n",
")\n",
"\n",
"# Show each train/test part under its name.\n",
"for label, frame in (\n",
"    (\"X_train\", X_train),\n",
"    (\"y_train\", y_train),\n",
"    (\"X_test\", X_test),\n",
"    (\"y_test\", y_test),\n",
"):\n",
"    display(label, frame)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование конвейера для классификации данных\n",
"\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"features_engineering -- трансформер для конструирования признаков (в этом конвейере не используется)\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"features_postprocessing -- трансформер для унитарного кодирования новых признаков (в этом конвейере не используется)\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков\n",
"\n",
"Конвейер выполняется последовательно.\n",
"\n",
"Трансформер выполняется параллельно для указанного набора колонок.\n",
"\n",
"Документация: \n",
"\n",
"https://scikit-learn.org/1.5/api/sklearn.pipeline.html\n",
"\n",
"https://scikit-learn.org/1.5/modules/generated/sklearn.compose.ColumnTransformer.html#sklearn.compose.ColumnTransformer"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"# The target column is removed from the feature matrix by the last pipeline step.\n",
"columns_to_drop = [\"smoker\"]\n",
"\n",
"# Columns are routed by dtype. \"smoker\" is already integer-encoded in df, so it is\n",
"# listed in num_columns and gets imputed and scaled before drop_columns discards it.\n",
"num_columns = [\n",
"    column\n",
"    for column in df.columns\n",
"    if df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
"    column\n",
"    for column in df.columns\n",
"    if df[column].dtype == \"object\"\n",
"]\n",
"\n",
"# Numeric branch: median imputation followed by standardisation.\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
"    [\n",
"        (\"imputer\", num_imputer),\n",
"        (\"scaler\", num_scaler),\n",
"    ]\n",
")\n",
"\n",
"# Categorical branch: fill gaps with \"unknown\", then one-hot encode and drop the\n",
"# first level of every feature.\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
"    [\n",
"        (\"imputer\", cat_imputer),\n",
"        (\"encoder\", cat_encoder),\n",
"    ]\n",
")\n",
"\n",
"# Apply each branch to its column group; any other column is passed through unchanged.\n",
"features_preprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"prepocessing_num\", preprocessing_num, num_columns),\n",
"        (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
"    ],\n",
"    remainder=\"passthrough\"\n",
")\n",
"\n",
"# Drop the columns listed in columns_to_drop and keep everything else.\n",
"drop_columns = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"drop_columns\", \"drop\", columns_to_drop),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"# Full preprocessing: transform the features first, then remove the target column.\n",
"pipeline_end = Pipeline(\n",
"    [\n",
"        (\"features_preprocessing\", features_preprocessing),\n",
"        (\"drop_columns\", drop_columns),\n",
"    ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Демонстрация работы конвейера для предобработки данных при классификации"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>charges</th>\n",
" <th>sex_male</th>\n",
" <th>region_northwest</th>\n",
" <th>region_southeast</th>\n",
" <th>region_southwest</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>671</th>\n",
" <td>-0.730722</td>\n",
" <td>0.085028</td>\n",
" <td>-0.907368</td>\n",
" <td>-0.769241</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>808</th>\n",
" <td>-1.513302</td>\n",
" <td>-0.081153</td>\n",
" <td>-0.907368</td>\n",
" <td>-0.999824</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>795</th>\n",
" <td>-0.873009</td>\n",
" <td>-0.348348</td>\n",
" <td>-0.907368</td>\n",
" <td>0.408827</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>576</th>\n",
" <td>-1.228727</td>\n",
" <td>-0.618800</td>\n",
" <td>-0.907368</td>\n",
" <td>-0.956079</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1232</th>\n",
" <td>1.047868</td>\n",
" <td>-0.982934</td>\n",
" <td>1.555858</td>\n",
" <td>-0.069302</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>105</th>\n",
" <td>-1.371015</td>\n",
" <td>-0.425736</td>\n",
" <td>-0.086293</td>\n",
" <td>0.347299</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>461</th>\n",
" <td>0.194145</td>\n",
" <td>-0.103963</td>\n",
" <td>-0.907368</td>\n",
" <td>0.723147</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2650</th>\n",
" <td>0.692150</td>\n",
" <td>0.441016</td>\n",
" <td>0.734783</td>\n",
" <td>-0.242218</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1674</th>\n",
" <td>1.403586</td>\n",
" <td>0.998214</td>\n",
" <td>-0.086293</td>\n",
" <td>2.834804</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2689</th>\n",
" <td>0.265288</td>\n",
" <td>-0.462394</td>\n",
" <td>-0.907368</td>\n",
" <td>2.009332</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2217 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" age bmi children charges sex_male region_northwest \\\n",
"671 -0.730722 0.085028 -0.907368 -0.769241 0.0 0.0 \n",
"808 -1.513302 -0.081153 -0.907368 -0.999824 1.0 0.0 \n",
"795 -0.873009 -0.348348 -0.907368 0.408827 1.0 1.0 \n",
"576 -1.228727 -0.618800 -0.907368 -0.956079 1.0 0.0 \n",
"1232 1.047868 -0.982934 1.555858 -0.069302 0.0 1.0 \n",
"... ... ... ... ... ... ... \n",
"105 -1.371015 -0.425736 -0.086293 0.347299 1.0 1.0 \n",
"461 0.194145 -0.103963 -0.907368 0.723147 1.0 0.0 \n",
"2650 0.692150 0.441016 0.734783 -0.242218 0.0 0.0 \n",
"1674 1.403586 0.998214 -0.086293 2.834804 0.0 0.0 \n",
"2689 0.265288 -0.462394 -0.907368 2.009332 1.0 0.0 \n",
"\n",
" region_southeast region_southwest \n",
"671 0.0 0.0 \n",
"808 1.0 0.0 \n",
"795 0.0 0.0 \n",
"576 1.0 0.0 \n",
"1232 0.0 0.0 \n",
"... ... ... \n",
"105 0.0 0.0 \n",
"461 0.0 1.0 \n",
"2650 0.0 0.0 \n",
"1674 0.0 0.0 \n",
"2689 0.0 1.0 \n",
"\n",
"[2217 rows x 8 columns]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit the preprocessing pipeline on the training part and label the result\n",
"# with the feature names the pipeline reports.\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"feature_names = pipeline_end.get_feature_names_out()\n",
"preprocessed_df = pd.DataFrame(preprocessing_result, columns=feature_names)\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для классификации\n",
"\n",
"logistic -- логистическая регрессия\n",
"\n",
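"ridge -- логистическая регрессия с L2-регуляризацией и балансировкой весов классов (используется вместо гребневого классификатора)\n",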
"\n",
"decision_tree -- дерево решений\n",
"\n",
"knn -- k-ближайших соседей\n",
"\n",
"naive_bayes -- наивный Байесовский классификатор\n",
"\n",
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"\n",
"random_forest -- метод случайного леса (набор деревьев решений)\n",
"\n",
"mlp -- многослойный персептрон (нейронная сеть)\n",
"\n",
"Документация: https://scikit-learn.org/1.5/supervised_learning.html"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"# Registry of candidate classifiers: name -> {\"model\": unfitted estimator}.\n",
"# Estimators that accept a seed receive the shared random_state so that\n",
"# results are repeatable between runs.\n",
"class_models = {\n",
"    \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
"    # NOTE(review): \"ridge\" is a class-balanced L2 logistic regression rather than\n",
"    # linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\") -- presumably because\n",
"    # the evaluation loop calls predict_proba; confirm before switching back.\n",
"    \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
"    \"decision_tree\": {\n",
"        \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
"    },\n",
"    \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
"    \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
"    \"gradient_boosting\": {\n",
"        # Seeded like the other tree-based models.\n",
"        \"model\": ensemble.GradientBoostingClassifier(\n",
"            n_estimators=210, random_state=random_state\n",
"        )\n",
"    },\n",
"    \"random_forest\": {\n",
"        \"model\": ensemble.RandomForestClassifier(\n",
"            max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
"        )\n",
"    },\n",
"    \"mlp\": {\n",
"        \"model\": neural_network.MLPClassifier(\n",
"            hidden_layer_sizes=(7,),\n",
"            max_iter=500,\n",
"            early_stopping=True,\n",
"            random_state=random_state,\n",
"        )\n",
"    },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np  # type: ignore\n",
"from sklearn import metrics\n",
"\n",
"for model_name, entry in class_models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    model = entry[\"model\"]\n",
"\n",
"    # Chain the shared preprocessing with the estimator and fit on the training part.\n",
"    model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
"    model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
"    # Train labels come from predict(); test labels come from thresholding the\n",
"    # second predict_proba column at 0.5.\n",
"    y_train_predict = model_pipeline.predict(X_train)\n",
"    y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
"    y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
"    # Store the fitted pipeline, the raw outputs and every metric next to the model.\n",
"    entry.update(\n",
"        {\n",
"            \"pipeline\": model_pipeline,\n",
"            \"probs\": y_test_probs,\n",
"            \"preds\": y_test_predict,\n",
"            \"Precision_train\": metrics.precision_score(y_train, y_train_predict),\n",
"            \"Precision_test\": metrics.precision_score(y_test, y_test_predict),\n",
"            \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
"            \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
"            \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
"            \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
"            \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_probs),\n",
"            \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
"            \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
"            \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
"            \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
"            \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
"        }\n",
"    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Сводная таблица оценок качества для использованных моделей классификации\n",
"\n",
"Документация: https://scikit-learn.org/1.5/modules/model_evaluation.html"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Матрица неточностей"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Two plots per row. The row count is rounded up so that an odd number of models\n",
"# still gets an axis for its last entry.\n",
"n_cols = 2\n",
"n_rows = (len(class_models) + n_cols - 1) // n_cols\n",
"\n",
"# squeeze=False keeps `ax` two-dimensional even when there is a single row.\n",
"_, ax = plt.subplots(\n",
"    n_rows, n_cols, figsize=(12, 10), sharex=False, sharey=False, squeeze=False\n",
")\n",
"for index, key in enumerate(class_models.keys()):\n",
"    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"non smoker\", \"smoker\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"    disp.ax_.set_title(key)\n",
"\n",
"# Hide any grid cell that did not receive a model.\n",
"for unused_ax in ax.ravel()[len(class_models):]:\n",
"    unused_ax.set_visible(False)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_98a81_row0_col0, #T_98a81_row0_col1, #T_98a81_row0_col2, #T_98a81_row0_col3, #T_98a81_row2_col2, #T_98a81_row3_col2, #T_98a81_row3_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row0_col4, #T_98a81_row0_col5, #T_98a81_row0_col6, #T_98a81_row0_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row1_col0 {\n",
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row1_col1 {\n",
" background-color: #7ad151;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row1_col2 {\n",
" background-color: #a5db36;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row1_col3 {\n",
" background-color: #98d83e;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row1_col4 {\n",
" background-color: #d45270;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row1_col5, #T_98a81_row2_col5 {\n",
" background-color: #d04d73;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row1_col6 {\n",
" background-color: #d5546e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row1_col7, #T_98a81_row2_col7 {\n",
" background-color: #d24f71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row2_col0 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row2_col1 {\n",
" background-color: #6ccd5a;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row2_col3 {\n",
" background-color: #9dd93b;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row2_col4, #T_98a81_row2_col6 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col0 {\n",
" background-color: #1e9b8a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col1 {\n",
" background-color: #23888e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col4 {\n",
" background-color: #b42e8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col5, #T_98a81_row5_col6 {\n",
" background-color: #b52f8c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col6 {\n",
" background-color: #bf3984;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col7, #T_98a81_row4_col6 {\n",
" background-color: #c13b82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row4_col0 {\n",
" background-color: #35b779;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row4_col1 {\n",
" background-color: #1f9f88;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row4_col2 {\n",
" background-color: #81d34d;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row4_col3 {\n",
" background-color: #69cd5b;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row4_col4 {\n",
" background-color: #b83289;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row4_col5, #T_98a81_row5_col7 {\n",
" background-color: #aa2395;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row4_col7 {\n",
" background-color: #b6308b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row5_col0 {\n",
" background-color: #21a585;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row5_col1 {\n",
" background-color: #21918c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row5_col2 {\n",
" background-color: #73d056;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row5_col3 {\n",
" background-color: #54c568;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row5_col4 {\n",
" background-color: #a82296;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row5_col5 {\n",
" background-color: #99159f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col0, #T_98a81_row6_col1, #T_98a81_row7_col2, #T_98a81_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col2 {\n",
" background-color: #22a785;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col3 {\n",
" background-color: #23a983;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col4 {\n",
" background-color: #5c01a6;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col5 {\n",
" background-color: #6c00a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col6 {\n",
" background-color: #7501a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col7 {\n",
" background-color: #8104a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row7_col0 {\n",
" background-color: #3bbb75;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row7_col1 {\n",
" background-color: #34b679;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row7_col4, #T_98a81_row7_col5, #T_98a81_row7_col6, #T_98a81_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_98a81\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_98a81_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_98a81_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_98a81_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_98a81_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_98a81_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_98a81_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_98a81_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_98a81_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row0\" class=\"row_heading level0 row0\" >gradient_boosting</th>\n",
" <td id=\"T_98a81_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_98a81_row0_col1\" class=\"data row0 col1\" >0.982609</td>\n",
" <td id=\"T_98a81_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_98a81_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_98a81_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_98a81_row0_col5\" class=\"data row0 col5\" >0.996396</td>\n",
" <td id=\"T_98a81_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_98a81_row0_col7\" class=\"data row0 col7\" >0.991228</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_98a81_row1_col0\" class=\"data row1 col0\" >0.976035</td>\n",
" <td id=\"T_98a81_row1_col1\" class=\"data row1 col1\" >0.956522</td>\n",
" <td id=\"T_98a81_row1_col2\" class=\"data row1 col2\" >0.993348</td>\n",
" <td id=\"T_98a81_row1_col3\" class=\"data row1 col3\" >0.973451</td>\n",
" <td id=\"T_98a81_row1_col4\" class=\"data row1 col4\" >0.993685</td>\n",
" <td id=\"T_98a81_row1_col5\" class=\"data row1 col5\" >0.985586</td>\n",
" <td id=\"T_98a81_row1_col6\" class=\"data row1 col6\" >0.984615</td>\n",
" <td id=\"T_98a81_row1_col7\" class=\"data row1 col7\" >0.964912</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_98a81_row2_col0\" class=\"data row2 col0\" >0.995585</td>\n",
" <td id=\"T_98a81_row2_col1\" class=\"data row2 col1\" >0.948718</td>\n",
" <td id=\"T_98a81_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_98a81_row2_col3\" class=\"data row2 col3\" >0.982301</td>\n",
" <td id=\"T_98a81_row2_col4\" class=\"data row2 col4\" >0.999098</td>\n",
" <td id=\"T_98a81_row2_col5\" class=\"data row2 col5\" >0.985586</td>\n",
" <td id=\"T_98a81_row2_col6\" class=\"data row2 col6\" >0.997788</td>\n",
" <td id=\"T_98a81_row2_col7\" class=\"data row2 col7\" >0.965217</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row3\" class=\"row_heading level0 row3\" >ridge</th>\n",
" <td id=\"T_98a81_row3_col0\" class=\"data row3 col0\" >0.846154</td>\n",
" <td id=\"T_98a81_row3_col1\" class=\"data row3 col1\" >0.837037</td>\n",
" <td id=\"T_98a81_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_98a81_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_98a81_row3_col4\" class=\"data row3 col4\" >0.963013</td>\n",
" <td id=\"T_98a81_row3_col5\" class=\"data row3 col5\" >0.960360</td>\n",
" <td id=\"T_98a81_row3_col6\" class=\"data row3 col6\" >0.916667</td>\n",
" <td id=\"T_98a81_row3_col7\" class=\"data row3 col7\" >0.911290</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_98a81_row4_col0\" class=\"data row4 col0\" >0.903846</td>\n",
" <td id=\"T_98a81_row4_col1\" class=\"data row4 col1\" >0.870690</td>\n",
" <td id=\"T_98a81_row4_col2\" class=\"data row4 col2\" >0.937916</td>\n",
" <td id=\"T_98a81_row4_col3\" class=\"data row4 col3\" >0.893805</td>\n",
" <td id=\"T_98a81_row4_col4\" class=\"data row4 col4\" >0.967073</td>\n",
" <td id=\"T_98a81_row4_col5\" class=\"data row4 col5\" >0.951351</td>\n",
" <td id=\"T_98a81_row4_col6\" class=\"data row4 col6\" >0.920566</td>\n",
" <td id=\"T_98a81_row4_col7\" class=\"data row4 col7\" >0.882096</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row5\" class=\"row_heading level0 row5\" >logistic</th>\n",
" <td id=\"T_98a81_row5_col0\" class=\"data row5 col0\" >0.867368</td>\n",
" <td id=\"T_98a81_row5_col1\" class=\"data row5 col1\" >0.849558</td>\n",
" <td id=\"T_98a81_row5_col2\" class=\"data row5 col2\" >0.913525</td>\n",
" <td id=\"T_98a81_row5_col3\" class=\"data row5 col3\" >0.849558</td>\n",
" <td id=\"T_98a81_row5_col4\" class=\"data row5 col4\" >0.953992</td>\n",
" <td id=\"T_98a81_row5_col5\" class=\"data row5 col5\" >0.938739</td>\n",
" <td id=\"T_98a81_row5_col6\" class=\"data row5 col6\" >0.889849</td>\n",
" <td id=\"T_98a81_row5_col7\" class=\"data row5 col7\" >0.849558</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_98a81_row6_col0\" class=\"data row6 col0\" >0.794045</td>\n",
" <td id=\"T_98a81_row6_col1\" class=\"data row6 col1\" >0.824742</td>\n",
" <td id=\"T_98a81_row6_col2\" class=\"data row6 col2\" >0.709534</td>\n",
" <td id=\"T_98a81_row6_col3\" class=\"data row6 col3\" >0.707965</td>\n",
" <td id=\"T_98a81_row6_col4\" class=\"data row6 col4\" >0.903473</td>\n",
" <td id=\"T_98a81_row6_col5\" class=\"data row6 col5\" >0.909910</td>\n",
" <td id=\"T_98a81_row6_col6\" class=\"data row6 col6\" >0.749415</td>\n",
" <td id=\"T_98a81_row6_col7\" class=\"data row6 col7\" >0.761905</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_98a81_row7_col0\" class=\"data row7 col0\" >0.910112</td>\n",
" <td id=\"T_98a81_row7_col1\" class=\"data row7 col1\" >0.907692</td>\n",
" <td id=\"T_98a81_row7_col2\" class=\"data row7 col2\" >0.538803</td>\n",
" <td id=\"T_98a81_row7_col3\" class=\"data row7 col3\" >0.522124</td>\n",
" <td id=\"T_98a81_row7_col4\" class=\"data row7 col4\" >0.895354</td>\n",
" <td id=\"T_98a81_row7_col5\" class=\"data row7 col5\" >0.891892</td>\n",
" <td id=\"T_98a81_row7_col6\" class=\"data row7 col6\" >0.676880</td>\n",
" <td id=\"T_98a81_row7_col7\" class=\"data row7 col7\" >0.662921</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2464b71fd10>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_fd909_row0_col0, #T_fd909_row2_col0 {\n",
" background-color: #8bd646;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row0_col1, #T_fd909_row2_col1 {\n",
" background-color: #90d743;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row0_col2, #T_fd909_row1_col3, #T_fd909_row1_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row0_col3, #T_fd909_row2_col3 {\n",
" background-color: #d24f71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row0_col4, #T_fd909_row2_col4 {\n",
" background-color: #d14e72;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row1_col0, #T_fd909_row1_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row1_col2 {\n",
" background-color: #d7566c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row2_col2 {\n",
" background-color: #cf4c74;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row3_col0, #T_fd909_row5_col1 {\n",
" background-color: #3bbb75;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row3_col1 {\n",
" background-color: #50c46a;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row3_col2 {\n",
" background-color: #bc3587;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row3_col3 {\n",
" background-color: #b32c8e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row3_col4 {\n",
" background-color: #b02991;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row4_col0 {\n",
" background-color: #4ec36b;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row4_col1 {\n",
" background-color: #65cb5e;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row4_col2 {\n",
" background-color: #99159f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row4_col3 {\n",
" background-color: #be3885;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row4_col4 {\n",
" background-color: #bd3786;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row5_col0 {\n",
" background-color: #29af7f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row5_col2 {\n",
" background-color: #9613a1;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row5_col3 {\n",
" background-color: #a62098;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row5_col4 {\n",
" background-color: #a01a9c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row6_col0 {\n",
" background-color: #20928c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row6_col1 {\n",
" background-color: #1fa088;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row6_col2 {\n",
" background-color: #7401a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row6_col3 {\n",
" background-color: #7d03a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row6_col4 {\n",
" background-color: #7201a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row7_col0, #T_fd909_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row7_col2, #T_fd909_row7_col3, #T_fd909_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_fd909\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_fd909_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_fd909_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_fd909_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_fd909_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_fd909_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_fd909_row0_col0\" class=\"data row0 col0\" >0.985586</td>\n",
" <td id=\"T_fd909_row0_col1\" class=\"data row0 col1\" >0.965217</td>\n",
" <td id=\"T_fd909_row0_col2\" class=\"data row0 col2\" >0.999039</td>\n",
" <td id=\"T_fd909_row0_col3\" class=\"data row0 col3\" >0.956130</td>\n",
" <td id=\"T_fd909_row0_col4\" class=\"data row0 col4\" >0.956360</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
" <td id=\"T_fd909_row1_col0\" class=\"data row1 col0\" >0.996396</td>\n",
" <td id=\"T_fd909_row1_col1\" class=\"data row1 col1\" >0.991228</td>\n",
" <td id=\"T_fd909_row1_col2\" class=\"data row1 col2\" >0.998118</td>\n",
" <td id=\"T_fd909_row1_col3\" class=\"data row1 col3\" >0.988961</td>\n",
" <td id=\"T_fd909_row1_col4\" class=\"data row1 col4\" >0.989021</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_fd909_row2_col0\" class=\"data row2 col0\" >0.985586</td>\n",
" <td id=\"T_fd909_row2_col1\" class=\"data row2 col1\" >0.964912</td>\n",
" <td id=\"T_fd909_row2_col2\" class=\"data row2 col2\" >0.995745</td>\n",
" <td id=\"T_fd909_row2_col3\" class=\"data row2 col3\" >0.955843</td>\n",
" <td id=\"T_fd909_row2_col4\" class=\"data row2 col4\" >0.955901</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_fd909_row3_col0\" class=\"data row3 col0\" >0.951351</td>\n",
" <td id=\"T_fd909_row3_col1\" class=\"data row3 col1\" >0.882096</td>\n",
" <td id=\"T_fd909_row3_col2\" class=\"data row3 col2\" >0.990049</td>\n",
" <td id=\"T_fd909_row3_col3\" class=\"data row3 col3\" >0.851456</td>\n",
" <td id=\"T_fd909_row3_col4\" class=\"data row3 col4\" >0.851572</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
" <td id=\"T_fd909_row4_col0\" class=\"data row4 col0\" >0.960360</td>\n",
" <td id=\"T_fd909_row4_col1\" class=\"data row4 col1\" >0.911290</td>\n",
" <td id=\"T_fd909_row4_col2\" class=\"data row4 col2\" >0.982001</td>\n",
" <td id=\"T_fd909_row4_col3\" class=\"data row4 col3\" >0.886026</td>\n",
" <td id=\"T_fd909_row4_col4\" class=\"data row4 col4\" >0.891838</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row5\" class=\"row_heading level0 row5\" >logistic</th>\n",
" <td id=\"T_fd909_row5_col0\" class=\"data row5 col0\" >0.938739</td>\n",
" <td id=\"T_fd909_row5_col1\" class=\"data row5 col1\" >0.849558</td>\n",
" <td id=\"T_fd909_row5_col2\" class=\"data row5 col2\" >0.981520</td>\n",
" <td id=\"T_fd909_row5_col3\" class=\"data row5 col3\" >0.811096</td>\n",
" <td id=\"T_fd909_row5_col4\" class=\"data row5 col4\" >0.811096</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_fd909_row6_col0\" class=\"data row6 col0\" >0.909910</td>\n",
" <td id=\"T_fd909_row6_col1\" class=\"data row6 col1\" >0.761905</td>\n",
" <td id=\"T_fd909_row6_col2\" class=\"data row6 col2\" >0.974813</td>\n",
" <td id=\"T_fd909_row6_col3\" class=\"data row6 col3\" >0.706746</td>\n",
" <td id=\"T_fd909_row6_col4\" class=\"data row6 col4\" >0.709879</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_fd909_row7_col0\" class=\"data row7 col0\" >0.891892</td>\n",
" <td id=\"T_fd909_row7_col1\" class=\"data row7 col1\" >0.662921</td>\n",
" <td id=\"T_fd909_row7_col2\" class=\"data row7 col2\" >0.968086</td>\n",
" <td id=\"T_fd909_row7_col3\" class=\"data row7 col3\" >0.604043</td>\n",
" <td id=\"T_fd909_row7_col4\" class=\"data row7 col4\" >0.636838</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x24665055160>"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'gradient_boosting'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 2'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>Predicted</th>\n",
" <th>sex</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>smoker</th>\n",
" <th>region</th>\n",
" <th>charges</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>583</th>\n",
" <td>32</td>\n",
" <td>1</td>\n",
" <td>female</td>\n",
" <td>23.65</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>southeast</td>\n",
" <td>17626.23951</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1969</th>\n",
" <td>32</td>\n",
" <td>1</td>\n",
" <td>female</td>\n",
" <td>23.65</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>southeast</td>\n",
" <td>17626.23951</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age Predicted sex bmi children smoker region charges\n",
"583 32 1 female 23.65 1 0 southeast 17626.23951\n",
"1969 32 1 female 23.65 1 0 southeast 17626.23951"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"smoker\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"age 32\n",
"sex female\n",
"bmi 23.65\n",
"children 1\n",
"smoker 0\n",
"region southeast\n",
"charges 17626.23951\n",
"Name: 1969, dtype: object\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>sex</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>smoker</th>\n",
" <th>region</th>\n",
" <th>charges</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1969</th>\n",
" <td>32</td>\n",
" <td>female</td>\n",
" <td>23.65</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>southeast</td>\n",
" <td>17626.23951</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age sex bmi children smoker region charges\n",
"1969 32 female 23.65 1 0 southeast 17626.23951"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>charges</th>\n",
" <th>sex_male</th>\n",
" <th>region_northwest</th>\n",
" <th>region_southeast</th>\n",
" <th>region_southwest</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1969</th>\n",
" <td>-0.517291</td>\n",
" <td>-1.138526</td>\n",
" <td>-0.086293</td>\n",
" <td>0.3527</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age bmi children charges sex_male region_northwest \\\n",
"1969 -0.517291 -1.138526 -0.086293 0.3527 0.0 0.0 \n",
"\n",
" region_southeast region_southwest \n",
"1969 1.0 0.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 1 (proba: [0.01087081 0.98912919])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 0'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 1969\n",
"print(X_test.loc[example_id, :])\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Подбор гиперпараметров методом поиска по сетке\n",
"\n",
"https://www.kaggle.com/code/sociopath00/random-forest-using-gridsearchcv\n",
"\n",
"https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\frenk\\OneDrive\\Рабочий стол\\MII_Salin_Oleg_PIbd-33\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'entropy',\n",
" 'model__max_depth': 10,\n",
" 'model__max_features': 'log2',\n",
" 'model__n_estimators': 250}"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10],\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=7,\n",
" max_features=\"sqrt\",\n",
" n_estimators=30,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_b2ee5_row0_col0, #T_b2ee5_row0_col1, #T_b2ee5_row0_col2, #T_b2ee5_row0_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_b2ee5_row0_col4, #T_b2ee5_row0_col5, #T_b2ee5_row0_col6, #T_b2ee5_row0_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b2ee5_row1_col0, #T_b2ee5_row1_col1, #T_b2ee5_row1_col2, #T_b2ee5_row1_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b2ee5_row1_col4, #T_b2ee5_row1_col5, #T_b2ee5_row1_col6, #T_b2ee5_row1_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_b2ee5\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_b2ee5_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_b2ee5_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_b2ee5_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_b2ee5_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_b2ee5_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_b2ee5_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_b2ee5_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_b2ee5_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_b2ee5_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_b2ee5_row0_col0\" class=\"data row0 col0\" >0.995585</td>\n",
" <td id=\"T_b2ee5_row0_col1\" class=\"data row0 col1\" >0.948718</td>\n",
" <td id=\"T_b2ee5_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_b2ee5_row0_col3\" class=\"data row0 col3\" >0.982301</td>\n",
" <td id=\"T_b2ee5_row0_col4\" class=\"data row0 col4\" >0.999098</td>\n",
" <td id=\"T_b2ee5_row0_col5\" class=\"data row0 col5\" >0.985586</td>\n",
" <td id=\"T_b2ee5_row0_col6\" class=\"data row0 col6\" >0.997788</td>\n",
" <td id=\"T_b2ee5_row0_col7\" class=\"data row0 col7\" >0.965217</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b2ee5_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_b2ee5_row1_col0\" class=\"data row1 col0\" >0.971800</td>\n",
" <td id=\"T_b2ee5_row1_col1\" class=\"data row1 col1\" >0.923077</td>\n",
" <td id=\"T_b2ee5_row1_col2\" class=\"data row1 col2\" >0.993348</td>\n",
" <td id=\"T_b2ee5_row1_col3\" class=\"data row1 col3\" >0.955752</td>\n",
" <td id=\"T_b2ee5_row1_col4\" class=\"data row1 col4\" >0.992783</td>\n",
" <td id=\"T_b2ee5_row1_col5\" class=\"data row1 col5\" >0.974775</td>\n",
" <td id=\"T_b2ee5_row1_col6\" class=\"data row1 col6\" >0.982456</td>\n",
" <td id=\"T_b2ee5_row1_col7\" class=\"data row1 col7\" >0.939130</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x24665e4e7b0>"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_79612_row0_col0, #T_79612_row0_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_79612_row0_col2, #T_79612_row0_col3, #T_79612_row0_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_79612_row1_col0, #T_79612_row1_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_79612_row1_col2, #T_79612_row1_col3, #T_79612_row1_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_79612\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_79612_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_79612_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_79612_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_79612_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_79612_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_79612_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_79612_row0_col0\" class=\"data row0 col0\" >0.985586</td>\n",
" <td id=\"T_79612_row0_col1\" class=\"data row0 col1\" >0.965217</td>\n",
" <td id=\"T_79612_row0_col2\" class=\"data row0 col2\" >0.999039</td>\n",
" <td id=\"T_79612_row0_col3\" class=\"data row0 col3\" >0.956130</td>\n",
" <td id=\"T_79612_row0_col4\" class=\"data row0 col4\" >0.956360</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_79612_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_79612_row1_col0\" class=\"data row1 col0\" >0.974775</td>\n",
" <td id=\"T_79612_row1_col1\" class=\"data row1 col1\" >0.939130</td>\n",
" <td id=\"T_79612_row1_col2\" class=\"data row1 col2\" >0.996276</td>\n",
" <td id=\"T_79612_row1_col3\" class=\"data row1 col3\" >0.923227</td>\n",
" <td id=\"T_79612_row1_col4\" class=\"data row1 col4\" >0.923450</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x24665e4c980>"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"smokers\", \"non smokers\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}