{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Загрузка данных"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" bmi | \n",
" children | \n",
" charges | \n",
" sex_male | \n",
" region_northwest | \n",
" region_southeast | \n",
" region_southwest | \n",
" smoker_yes | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 19 | \n",
" 27.900 | \n",
" 0 | \n",
" 16884.92400 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" 1 | \n",
" 18 | \n",
" 33.770 | \n",
" 1 | \n",
" 1725.55230 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2 | \n",
" 28 | \n",
" 33.000 | \n",
" 3 | \n",
" 4449.46200 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 3 | \n",
" 33 | \n",
" 22.705 | \n",
" 0 | \n",
" 21984.47061 | \n",
" 1.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 4 | \n",
" 32 | \n",
" 28.880 | \n",
" 0 | \n",
" 3866.85520 | \n",
" 1.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 2767 | \n",
" 47 | \n",
" 45.320 | \n",
" 1 | \n",
" 8569.86180 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2768 | \n",
" 21 | \n",
" 34.600 | \n",
" 0 | \n",
" 2020.17700 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2769 | \n",
" 19 | \n",
" 26.030 | \n",
" 1 | \n",
" 16450.89470 | \n",
" 1.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" 2770 | \n",
" 23 | \n",
" 18.715 | \n",
" 0 | \n",
" 21595.38229 | \n",
" 1.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2771 | \n",
" 54 | \n",
" 31.600 | \n",
" 0 | \n",
" 9850.43200 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
2772 rows × 9 columns
\n",
"
"
],
"text/plain": [
" age bmi children charges sex_male region_northwest \\\n",
"0 19 27.900 0 16884.92400 0.0 0.0 \n",
"1 18 33.770 1 1725.55230 1.0 0.0 \n",
"2 28 33.000 3 4449.46200 1.0 0.0 \n",
"3 33 22.705 0 21984.47061 1.0 1.0 \n",
"4 32 28.880 0 3866.85520 1.0 1.0 \n",
"... ... ... ... ... ... ... \n",
"2767 47 45.320 1 8569.86180 0.0 0.0 \n",
"2768 21 34.600 0 2020.17700 0.0 0.0 \n",
"2769 19 26.030 1 16450.89470 1.0 1.0 \n",
"2770 23 18.715 0 21595.38229 1.0 1.0 \n",
"2771 54 31.600 0 9850.43200 1.0 0.0 \n",
"\n",
" region_southeast region_southwest smoker_yes \n",
"0 0.0 1.0 1.0 \n",
"1 1.0 0.0 0.0 \n",
"2 1.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"... ... ... ... \n",
"2767 1.0 0.0 0.0 \n",
"2768 0.0 1.0 0.0 \n",
"2769 0.0 0.0 1.0 \n",
"2770 0.0 0.0 0.0 \n",
"2771 0.0 1.0 0.0 \n",
"\n",
"[2772 rows x 9 columns]"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn import set_config\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"# Make scikit-learn transformers return pandas DataFrames instead of arrays.\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Single seed used by every stochastic step in this notebook.\n",
"random_state = 9\n",
"\n",
"# Categorical columns to one-hot encode; defined once so encoding, naming and\n",
"# dropping of the original columns can never get out of sync.\n",
"categorical_columns = [\"sex\", \"region\", \"smoker\"]\n",
"\n",
"df = pd.read_csv(\"data/Medical_insurance.csv\", index_col=False)\n",
"\n",
"# drop=\"first\" removes one dummy per feature to avoid perfectly collinear columns.\n",
"encoder = OneHotEncoder(sparse_output=False, drop=\"first\")\n",
"encoded_values = encoder.fit_transform(df[categorical_columns])\n",
"encoded_columns = encoder.get_feature_names_out(categorical_columns)\n",
"# index=df.index keeps the dummy rows aligned with the source rows in concat.\n",
"encoded_values_df = pd.DataFrame(\n",
"    encoded_values, columns=encoded_columns, index=df.index\n",
")\n",
"\n",
"# Replace the raw categorical columns with their encoded counterparts.\n",
"df = pd.concat([df.drop(columns=categorical_columns), encoded_values_df], axis=1)\n",
"\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование выборок"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" bmi | \n",
" children | \n",
" sex_male | \n",
" region_northwest | \n",
" region_southeast | \n",
" region_southwest | \n",
" smoker_yes | \n",
"
\n",
" \n",
" \n",
" \n",
" 2146 | \n",
" 22 | \n",
" 34.580 | \n",
" 2 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 472 | \n",
" 19 | \n",
" 29.800 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 801 | \n",
" 64 | \n",
" 35.970 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 84 | \n",
" 37 | \n",
" 34.800 | \n",
" 2 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" 2028 | \n",
" 61 | \n",
" 33.915 | \n",
" 0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 979 | \n",
" 36 | \n",
" 29.920 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1475 | \n",
" 55 | \n",
" 26.980 | \n",
" 0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2547 | \n",
" 34 | \n",
" 42.130 | \n",
" 2 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2553 | \n",
" 29 | \n",
" 24.600 | \n",
" 2 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1974 | \n",
" 61 | \n",
" 35.910 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
2217 rows × 8 columns
\n",
"
"
],
"text/plain": [
" age bmi children sex_male region_northwest region_southeast \\\n",
"2146 22 34.580 2 0.0 0.0 0.0 \n",
"472 19 29.800 0 0.0 0.0 0.0 \n",
"801 64 35.970 0 0.0 0.0 1.0 \n",
"84 37 34.800 2 0.0 0.0 0.0 \n",
"2028 61 33.915 0 1.0 0.0 0.0 \n",
"... ... ... ... ... ... ... \n",
"979 36 29.920 0 0.0 0.0 1.0 \n",
"1475 55 26.980 0 0.0 1.0 0.0 \n",
"2547 34 42.130 2 1.0 0.0 1.0 \n",
"2553 29 24.600 2 0.0 0.0 0.0 \n",
"1974 61 35.910 0 0.0 0.0 0.0 \n",
"\n",
" region_southwest smoker_yes \n",
"2146 0.0 0.0 \n",
"472 1.0 0.0 \n",
"801 0.0 0.0 \n",
"84 1.0 1.0 \n",
"2028 0.0 0.0 \n",
"... ... ... \n",
"979 0.0 0.0 \n",
"1475 0.0 0.0 \n",
"2547 0.0 0.0 \n",
"2553 1.0 0.0 \n",
"1974 0.0 0.0 \n",
"\n",
"[2217 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" charges | \n",
"
\n",
" \n",
" \n",
" \n",
" 2146 | \n",
" 3925.75820 | \n",
"
\n",
" \n",
" 472 | \n",
" 1744.46500 | \n",
"
\n",
" \n",
" 801 | \n",
" 14313.84630 | \n",
"
\n",
" \n",
" 84 | \n",
" 39836.51900 | \n",
"
\n",
" \n",
" 2028 | \n",
" 13143.86485 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 979 | \n",
" 4889.03680 | \n",
"
\n",
" \n",
" 1475 | \n",
" 11082.57720 | \n",
"
\n",
" \n",
" 2547 | \n",
" 5124.18870 | \n",
"
\n",
" \n",
" 2553 | \n",
" 4529.47700 | \n",
"
\n",
" \n",
" 1974 | \n",
" 13635.63790 | \n",
"
\n",
" \n",
"
\n",
"
2217 rows × 1 columns
\n",
"
"
],
"text/plain": [
" charges\n",
"2146 3925.75820\n",
"472 1744.46500\n",
"801 14313.84630\n",
"84 39836.51900\n",
"2028 13143.86485\n",
"... ...\n",
"979 4889.03680\n",
"1475 11082.57720\n",
"2547 5124.18870\n",
"2553 4529.47700\n",
"1974 13635.63790\n",
"\n",
"[2217 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" bmi | \n",
" children | \n",
" sex_male | \n",
" region_northwest | \n",
" region_southeast | \n",
" region_southwest | \n",
" smoker_yes | \n",
"
\n",
" \n",
" \n",
" \n",
" 1101 | \n",
" 53 | \n",
" 28.600 | \n",
" 3 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2025 | \n",
" 56 | \n",
" 33.660 | \n",
" 4 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 307 | \n",
" 30 | \n",
" 33.330 | \n",
" 1 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 840 | \n",
" 21 | \n",
" 31.100 | \n",
" 0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2090 | \n",
" 47 | \n",
" 29.545 | \n",
" 1 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 1587 | \n",
" 48 | \n",
" 32.230 | \n",
" 1 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1177 | \n",
" 40 | \n",
" 27.400 | \n",
" 1 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1259 | \n",
" 52 | \n",
" 23.180 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1291 | \n",
" 19 | \n",
" 34.900 | \n",
" 0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" 2040 | \n",
" 59 | \n",
" 35.200 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
555 rows × 8 columns
\n",
"
"
],
"text/plain": [
" age bmi children sex_male region_northwest region_southeast \\\n",
"1101 53 28.600 3 1.0 0.0 0.0 \n",
"2025 56 33.660 4 1.0 0.0 1.0 \n",
"307 30 33.330 1 0.0 0.0 1.0 \n",
"840 21 31.100 0 1.0 0.0 0.0 \n",
"2090 47 29.545 1 0.0 1.0 0.0 \n",
"... ... ... ... ... ... ... \n",
"1587 48 32.230 1 0.0 0.0 1.0 \n",
"1177 40 27.400 1 0.0 0.0 0.0 \n",
"1259 52 23.180 0 0.0 0.0 0.0 \n",
"1291 19 34.900 0 1.0 0.0 0.0 \n",
"2040 59 35.200 0 0.0 0.0 1.0 \n",
"\n",
" region_southwest smoker_yes \n",
"1101 1.0 0.0 \n",
"2025 0.0 0.0 \n",
"307 0.0 0.0 \n",
"840 1.0 0.0 \n",
"2090 0.0 0.0 \n",
"... ... ... \n",
"1587 0.0 0.0 \n",
"1177 1.0 0.0 \n",
"1259 0.0 0.0 \n",
"1291 1.0 1.0 \n",
"2040 0.0 0.0 \n",
"\n",
"[555 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" charges | \n",
"
\n",
" \n",
" \n",
" \n",
" 1101 | \n",
" 11253.42100 | \n",
"
\n",
" \n",
" 2025 | \n",
" 12949.15540 | \n",
"
\n",
" \n",
" 307 | \n",
" 4151.02870 | \n",
"
\n",
" \n",
" 840 | \n",
" 1526.31200 | \n",
"
\n",
" \n",
" 2090 | \n",
" 8930.93455 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 1587 | \n",
" 8871.15170 | \n",
"
\n",
" \n",
" 1177 | \n",
" 6496.88600 | \n",
"
\n",
" \n",
" 1259 | \n",
" 10197.77220 | \n",
"
\n",
" \n",
" 1291 | \n",
" 34828.65400 | \n",
"
\n",
" \n",
" 2040 | \n",
" 12244.53100 | \n",
"
\n",
" \n",
"
\n",
"
555 rows × 1 columns
\n",
"
"
],
"text/plain": [
" charges\n",
"1101 11253.42100\n",
"2025 12949.15540\n",
"307 4151.02870\n",
"840 1526.31200\n",
"2090 8930.93455\n",
"... ...\n",
"1587 8871.15170\n",
"1177 6496.88600\n",
"1259 10197.77220\n",
"1291 34828.65400\n",
"2040 12244.53100\n",
"\n",
"[555 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from utils import split_stratified_into_train_val_test\n",
"\n",
"# 80/20 train/test split stratified by age; no validation split is requested.\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
"    df,\n",
"    stratify_colname=\"age\",\n",
"    target_colname=\"charges\",\n",
"    frac_train=0.80,\n",
"    frac_val=0,\n",
"    frac_test=0.20,\n",
"    random_state=random_state,\n",
")\n",
"\n",
"# The feature frames still carry the target column; remove it from them.\n",
"X_train, X_test = (part.drop(columns=[\"charges\"]) for part in (X_train, X_test))\n",
"\n",
"for caption, frame in [\n",
"    (\"X_train\", X_train),\n",
"    (\"y_train\", y_train),\n",
"    (\"X_test\", X_test),\n",
"    (\"y_test\", y_test),\n",
"]:\n",
"    display(caption, frame)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, neighbors, neural_network, tree\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"# Candidate regressors, keyed by the name used in the results table.\n",
"# `random_state` is defined once in the data-loading cell; it is deliberately\n",
"# not redefined here so a single seed controls the whole notebook.\n",
"models = {\n",
"    \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
"    # PolynomialFeatures emits a constant bias column by default, so the\n",
"    # regression itself is fitted without an intercept.\n",
"    \"linear_poly\": {\n",
"        \"model\": make_pipeline(\n",
"            PolynomialFeatures(degree=2),\n",
"            linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
"        )\n",
"    },\n",
"    \"linear_interact\": {\n",
"        \"model\": make_pipeline(\n",
"            PolynomialFeatures(interaction_only=True),\n",
"            linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
"        )\n",
"    },\n",
"    \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
"    \"decision_tree\": {\n",
"        \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
"    },\n",
"    \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
"    \"random_forest\": {\n",
"        \"model\": ensemble.RandomForestRegressor(\n",
"            max_depth=7, random_state=random_state, n_jobs=-1\n",
"        )\n",
"    },\n",
"    # Numeric inputs for this model are standardized in the training loop.\n",
"    \"mlp\": {\n",
"        \"model\": neural_network.MLPRegressor(\n",
"            activation=\"tanh\",\n",
"            hidden_layer_sizes=(3,),\n",
"            max_iter=500,\n",
"            early_stopping=True,\n",
"            random_state=random_state,\n",
"        )\n",
"    },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Определение функции для стандартизации числовых значений для MLP"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"from pandas import DataFrame\n",
"from sklearn import preprocessing\n",
"\n",
"# Numeric feature columns that are standardized before feeding the MLP.\n",
"STD_COLUMNS = [\"age\", \"bmi\", \"children\"]\n",
"\n",
"# One scaler for all numeric columns: StandardScaler works column-wise, and\n",
"# keeping a single fitted instance lets the test split reuse the statistics\n",
"# learned on the training split.\n",
"stndart_scaler = preprocessing.StandardScaler()\n",
"\n",
"\n",
"def std_q(df: DataFrame, fit: bool = True) -> DataFrame:\n",
"    \"\"\"Standardize the ``age``, ``bmi`` and ``children`` columns of ``df``.\n",
"\n",
"    The frame is modified in place and also returned for convenience.\n",
"\n",
"    Args:\n",
"        df: Feature frame containing every column listed in ``STD_COLUMNS``.\n",
"        fit: When True (default), fit ``stndart_scaler`` on ``df`` before\n",
"            transforming it. Pass False for test data so it is scaled with\n",
"            the statistics of the previously fitted training data.\n",
"\n",
"    Returns:\n",
"        The same ``df`` object with the numeric columns replaced by their\n",
"        standardized float values.\n",
"\n",
"    Raises:\n",
"        sklearn.exceptions.NotFittedError: If ``fit`` is False and the scaler\n",
"            has not been fitted yet.\n",
"    \"\"\"\n",
"    values = df[STD_COLUMNS].to_numpy(dtype=float)\n",
"    if fit:\n",
"        scaled = stndart_scaler.fit_transform(values)\n",
"    else:\n",
"        scaled = stndart_scaler.transform(values)\n",
"    # np.asarray drops the DataFrame wrapper produced under\n",
"    # set_config(transform_output=\"pandas\"): its RangeIndex would not align\n",
"    # with the shuffled index of the train/test splits.\n",
"    df[STD_COLUMNS] = np.asarray(scaled)\n",
"    return df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение и оценка моделей с помощью различных алгоритмов"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import math\n",
"from pandas import DataFrame\n",
"from sklearn import metrics\n",
"\n",
"for model_name, model_info in models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"\n",
"    # Work on copies so scaling for one model never affects the others.\n",
"    x_train: DataFrame = X_train.copy()\n",
"    x_test: DataFrame = X_test.copy()\n",
"\n",
"    if model_name == \"mlp\":\n",
"        # Fit the scaler on the training split only and reuse its statistics\n",
"        # for the test split, so no information from the test data leaks in.\n",
"        x_train = std_q(x_train)\n",
"        x_test = std_q(x_test, fit=False)\n",
"\n",
"    fitted_model = model_info[\"model\"].fit(x_train.values, y_train.values.ravel())\n",
"    y_train_pred = fitted_model.predict(x_train.values)\n",
"    y_test_pred = fitted_model.predict(x_test.values)\n",
"\n",
"    model_info[\"fitted\"] = fitted_model\n",
"    model_info[\"train_preds\"] = y_train_pred\n",
"    model_info[\"preds\"] = y_test_pred\n",
"    model_info[\"RMSE_train\"] = math.sqrt(\n",
"        metrics.mean_squared_error(y_train, y_train_pred)\n",
"    )\n",
"    model_info[\"RMSE_test\"] = math.sqrt(metrics.mean_squared_error(y_test, y_test_pred))\n",
"    # NOTE: \"RMAE\" here is the square root of the mean absolute error.\n",
"    model_info[\"RMAE_test\"] = math.sqrt(\n",
"        metrics.mean_absolute_error(y_test, y_test_pred)\n",
"    )\n",
"    model_info[\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод результатов оценки"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
" \n",
" \n",
" | \n",
" RMSE_train | \n",
" RMSE_test | \n",
" RMAE_test | \n",
" R2_test | \n",
"
\n",
" \n",
" \n",
" \n",
" random_forest | \n",
" 3221.469707 | \n",
" 3953.661053 | \n",
" 45.741609 | \n",
" 0.901103 | \n",
"
\n",
" \n",
" decision_tree | \n",
" 3643.279193 | \n",
" 4288.040726 | \n",
" 47.359073 | \n",
" 0.883668 | \n",
"
\n",
" \n",
" linear_poly | \n",
" 4731.024654 | \n",
" 4868.817371 | \n",
" 54.257745 | \n",
" 0.850021 | \n",
"
\n",
" \n",
" linear_interact | \n",
" 4776.393716 | \n",
" 4938.699556 | \n",
" 54.641209 | \n",
" 0.845685 | \n",
"
\n",
" \n",
" ridge | \n",
" 6028.427617 | \n",
" 6216.544081 | \n",
" 65.584948 | \n",
" 0.755499 | \n",
"
\n",
" \n",
" linear | \n",
" 6028.426993 | \n",
" 6216.588829 | \n",
" 65.580879 | \n",
" 0.755496 | \n",
"
\n",
" \n",
" knn | \n",
" 8230.959070 | \n",
" 9715.102581 | \n",
" 81.129201 | \n",
" 0.402859 | \n",
"
\n",
" \n",
" mlp | \n",
" 17848.198895 | \n",
" 18518.275054 | \n",
" 116.605174 | \n",
" -1.169619 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
""
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One row per model, restricted to the evaluation metrics.\n",
"reg_metrics = pd.DataFrame.from_dict(models, orient=\"index\")[\n",
"    [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"(\n",
"    reg_metrics.sort_values(by=\"RMSE_test\")\n",
"    .style.background_gradient(\n",
"        cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"]\n",
"    )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Получение лучшей модели"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'random_forest'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# The best model is the row label with the lowest test RMSE.\n",
"best_model = str(reg_metrics[\"RMSE_test\"].idxmin())\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод для обучающей выборки"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" bmi | \n",
" children | \n",
" sex_male | \n",
" region_northwest | \n",
" region_southeast | \n",
" region_southwest | \n",
" smoker_yes | \n",
" charges | \n",
" ChargesPred | \n",
"
\n",
" \n",
" \n",
" \n",
" 2146 | \n",
" 22 | \n",
" 34.580 | \n",
" 2 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 3925.75820 | \n",
" 4868.448184 | \n",
"
\n",
" \n",
" 472 | \n",
" 19 | \n",
" 29.800 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1744.46500 | \n",
" 3011.805771 | \n",
"
\n",
" \n",
" 801 | \n",
" 64 | \n",
" 35.970 | \n",
" 0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 14313.84630 | \n",
" 13841.766282 | \n",
"
\n",
" \n",
" 84 | \n",
" 37 | \n",
" 34.800 | \n",
" 2 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 1.0 | \n",
" 39836.51900 | \n",
" 39427.673528 | \n",
"
\n",
" \n",
" 2028 | \n",
" 61 | \n",
" 33.915 | \n",
" 0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 13143.86485 | \n",
" 13575.291528 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age bmi children sex_male region_northwest region_southeast \\\n",
"2146 22 34.580 2 0.0 0.0 0.0 \n",
"472 19 29.800 0 0.0 0.0 0.0 \n",
"801 64 35.970 0 0.0 0.0 1.0 \n",
"84 37 34.800 2 0.0 0.0 0.0 \n",
"2028 61 33.915 0 1.0 0.0 0.0 \n",
"\n",
" region_southwest smoker_yes charges ChargesPred \n",
"2146 0.0 0.0 3925.75820 4868.448184 \n",
"472 1.0 0.0 1744.46500 3011.805771 \n",
"801 0.0 0.0 14313.84630 13841.766282 \n",
"84 1.0 1.0 39836.51900 39427.673528 \n",
"2028 0.0 0.0 13143.86485 13575.291528 "
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Features, true target and best-model prediction side by side (train split).\n",
"(\n",
"    X_train.join(y_train)\n",
"    .assign(\n",
"        ChargesPred=pd.Series(\n",
"            models[best_model][\"train_preds\"], index=y_train.index\n",
"        )\n",
"    )\n",
"    .head(5)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод для тестовой выборки"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" bmi | \n",
" children | \n",
" sex_male | \n",
" region_northwest | \n",
" region_southeast | \n",
" region_southwest | \n",
" smoker_yes | \n",
" charges | \n",
" ChargesPred | \n",
"
\n",
" \n",
" \n",
" \n",
" 1101 | \n",
" 53 | \n",
" 28.600 | \n",
" 3 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 11253.42100 | \n",
" 12139.772544 | \n",
"
\n",
" \n",
" 2025 | \n",
" 56 | \n",
" 33.660 | \n",
" 4 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 12949.15540 | \n",
" 14977.306757 | \n",
"
\n",
" \n",
" 307 | \n",
" 30 | \n",
" 33.330 | \n",
" 1 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 4151.02870 | \n",
" 5778.492115 | \n",
"
\n",
" \n",
" 840 | \n",
" 21 | \n",
" 31.100 | \n",
" 0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 1526.31200 | \n",
" 3324.843009 | \n",
"
\n",
" \n",
" 2090 | \n",
" 47 | \n",
" 29.545 | \n",
" 1 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 8930.93455 | \n",
" 11318.629065 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age bmi children sex_male region_northwest region_southeast \\\n",
"1101 53 28.600 3 1.0 0.0 0.0 \n",
"2025 56 33.660 4 1.0 0.0 1.0 \n",
"307 30 33.330 1 0.0 0.0 1.0 \n",
"840 21 31.100 0 1.0 0.0 0.0 \n",
"2090 47 29.545 1 0.0 1.0 0.0 \n",
"\n",
" region_southwest smoker_yes charges ChargesPred \n",
"1101 1.0 0.0 11253.42100 12139.772544 \n",
"2025 0.0 0.0 12949.15540 14977.306757 \n",
"307 0.0 0.0 4151.02870 5778.492115 \n",
"840 1.0 0.0 1526.31200 3324.843009 \n",
"2090 0.0 0.0 8930.93455 11318.629065 "
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Features, true target and best-model prediction side by side (test split).\n",
"(\n",
"    X_test.join(y_test)\n",
"    .assign(\n",
"        ChargesPred=pd.Series(\n",
"            models[best_model][\"preds\"], index=y_test.index\n",
"        )\n",
"    )\n",
"    .head(5)\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}