1594 lines
51 KiB
Plaintext
1594 lines
51 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Загрузка данных"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>children</th>\n",
|
||
" <th>charges</th>\n",
|
||
" <th>sex_male</th>\n",
|
||
" <th>region_northwest</th>\n",
|
||
" <th>region_southeast</th>\n",
|
||
" <th>region_southwest</th>\n",
|
||
" <th>smoker_yes</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>19</td>\n",
|
||
" <td>27.900</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>16884.92400</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>18</td>\n",
|
||
" <td>33.770</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>1725.55230</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>28</td>\n",
|
||
" <td>33.000</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>4449.46200</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>33</td>\n",
|
||
" <td>22.705</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>21984.47061</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>32</td>\n",
|
||
" <td>28.880</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>3866.85520</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2767</th>\n",
|
||
" <td>47</td>\n",
|
||
" <td>45.320</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>8569.86180</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2768</th>\n",
|
||
" <td>21</td>\n",
|
||
" <td>34.600</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>2020.17700</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2769</th>\n",
|
||
" <td>19</td>\n",
|
||
" <td>26.030</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>16450.89470</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2770</th>\n",
|
||
" <td>23</td>\n",
|
||
" <td>18.715</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>21595.38229</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2771</th>\n",
|
||
" <td>54</td>\n",
|
||
" <td>31.600</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>9850.43200</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>2772 rows × 9 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age bmi children charges sex_male region_northwest \\\n",
|
||
"0 19 27.900 0 16884.92400 0.0 0.0 \n",
|
||
"1 18 33.770 1 1725.55230 1.0 0.0 \n",
|
||
"2 28 33.000 3 4449.46200 1.0 0.0 \n",
|
||
"3 33 22.705 0 21984.47061 1.0 1.0 \n",
|
||
"4 32 28.880 0 3866.85520 1.0 1.0 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"2767 47 45.320 1 8569.86180 0.0 0.0 \n",
|
||
"2768 21 34.600 0 2020.17700 0.0 0.0 \n",
|
||
"2769 19 26.030 1 16450.89470 1.0 1.0 \n",
|
||
"2770 23 18.715 0 21595.38229 1.0 1.0 \n",
|
||
"2771 54 31.600 0 9850.43200 1.0 0.0 \n",
|
||
"\n",
|
||
" region_southeast region_southwest smoker_yes \n",
|
||
"0 0.0 1.0 1.0 \n",
|
||
"1 1.0 0.0 0.0 \n",
|
||
"2 1.0 0.0 0.0 \n",
|
||
"3 0.0 0.0 0.0 \n",
|
||
"4 0.0 0.0 0.0 \n",
|
||
"... ... ... ... \n",
|
||
"2767 1.0 0.0 0.0 \n",
|
||
"2768 0.0 1.0 0.0 \n",
|
||
"2769 0.0 0.0 1.0 \n",
|
||
"2770 0.0 0.0 0.0 \n",
|
||
"2771 0.0 1.0 0.0 \n",
|
||
"\n",
|
||
"[2772 rows x 9 columns]"
|
||
]
|
||
},
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"from sklearn.preprocessing import OneHotEncoder\n",
|
||
"import numpy as np # type: ignore\n",
|
||
"\n",
|
||
"from sklearn import set_config\n",
|
||
"\n",
|
||
"set_config(transform_output=\"pandas\")\n",
|
||
"\n",
|
||
"random_state = 9\n",
|
||
"\n",
|
||
"# Load the raw table; its categorical columns are one-hot encoded below.\n",
|
||
"df = pd.read_csv(\"data/Medical_insurance.csv\", index_col=False)\n",
|
||
"\n",
|
||
"categorical_columns = [\"sex\", \"region\", \"smoker\"]\n",
|
||
"\n",
|
||
"# drop=\"first\" leaves out one dummy column per categorical feature.\n",
|
||
"encoder = OneHotEncoder(sparse_output=False, drop=\"first\")\n",
|
||
"encoded_values = encoder.fit_transform(df[categorical_columns])\n",
|
||
"encoded_columns = encoder.get_feature_names_out(categorical_columns)\n",
|
||
"encoded_values_df = pd.DataFrame(encoded_values, columns=encoded_columns)\n",
|
||
"\n",
|
||
"# Replace the categorical columns with their encoded counterparts.\n",
|
||
"df = pd.concat(\n",
|
||
"    [df.drop(categorical_columns, axis=1), encoded_values_df],\n",
|
||
"    axis=1,\n",
|
||
")\n",
|
||
"\n",
|
||
"df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Формирование выборок"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>children</th>\n",
|
||
" <th>sex_male</th>\n",
|
||
" <th>region_northwest</th>\n",
|
||
" <th>region_southeast</th>\n",
|
||
" <th>region_southwest</th>\n",
|
||
" <th>smoker_yes</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2146</th>\n",
|
||
" <td>22</td>\n",
|
||
" <td>34.580</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>472</th>\n",
|
||
" <td>19</td>\n",
|
||
" <td>29.800</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>801</th>\n",
|
||
" <td>64</td>\n",
|
||
" <td>35.970</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>84</th>\n",
|
||
" <td>37</td>\n",
|
||
" <td>34.800</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2028</th>\n",
|
||
" <td>61</td>\n",
|
||
" <td>33.915</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>979</th>\n",
|
||
" <td>36</td>\n",
|
||
" <td>29.920</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1475</th>\n",
|
||
" <td>55</td>\n",
|
||
" <td>26.980</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2547</th>\n",
|
||
" <td>34</td>\n",
|
||
" <td>42.130</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2553</th>\n",
|
||
" <td>29</td>\n",
|
||
" <td>24.600</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1974</th>\n",
|
||
" <td>61</td>\n",
|
||
" <td>35.910</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>2217 rows × 8 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age bmi children sex_male region_northwest region_southeast \\\n",
|
||
"2146 22 34.580 2 0.0 0.0 0.0 \n",
|
||
"472 19 29.800 0 0.0 0.0 0.0 \n",
|
||
"801 64 35.970 0 0.0 0.0 1.0 \n",
|
||
"84 37 34.800 2 0.0 0.0 0.0 \n",
|
||
"2028 61 33.915 0 1.0 0.0 0.0 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"979 36 29.920 0 0.0 0.0 1.0 \n",
|
||
"1475 55 26.980 0 0.0 1.0 0.0 \n",
|
||
"2547 34 42.130 2 1.0 0.0 1.0 \n",
|
||
"2553 29 24.600 2 0.0 0.0 0.0 \n",
|
||
"1974 61 35.910 0 0.0 0.0 0.0 \n",
|
||
"\n",
|
||
" region_southwest smoker_yes \n",
|
||
"2146 0.0 0.0 \n",
|
||
"472 1.0 0.0 \n",
|
||
"801 0.0 0.0 \n",
|
||
"84 1.0 1.0 \n",
|
||
"2028 0.0 0.0 \n",
|
||
"... ... ... \n",
|
||
"979 0.0 0.0 \n",
|
||
"1475 0.0 0.0 \n",
|
||
"2547 0.0 0.0 \n",
|
||
"2553 1.0 0.0 \n",
|
||
"1974 0.0 0.0 \n",
|
||
"\n",
|
||
"[2217 rows x 8 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>charges</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2146</th>\n",
|
||
" <td>3925.75820</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>472</th>\n",
|
||
" <td>1744.46500</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>801</th>\n",
|
||
" <td>14313.84630</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>84</th>\n",
|
||
" <td>39836.51900</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2028</th>\n",
|
||
" <td>13143.86485</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>979</th>\n",
|
||
" <td>4889.03680</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1475</th>\n",
|
||
" <td>11082.57720</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2547</th>\n",
|
||
" <td>5124.18870</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2553</th>\n",
|
||
" <td>4529.47700</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1974</th>\n",
|
||
" <td>13635.63790</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>2217 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" charges\n",
|
||
"2146 3925.75820\n",
|
||
"472 1744.46500\n",
|
||
"801 14313.84630\n",
|
||
"84 39836.51900\n",
|
||
"2028 13143.86485\n",
|
||
"... ...\n",
|
||
"979 4889.03680\n",
|
||
"1475 11082.57720\n",
|
||
"2547 5124.18870\n",
|
||
"2553 4529.47700\n",
|
||
"1974 13635.63790\n",
|
||
"\n",
|
||
"[2217 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>children</th>\n",
|
||
" <th>sex_male</th>\n",
|
||
" <th>region_northwest</th>\n",
|
||
" <th>region_southeast</th>\n",
|
||
" <th>region_southwest</th>\n",
|
||
" <th>smoker_yes</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>1101</th>\n",
|
||
" <td>53</td>\n",
|
||
" <td>28.600</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2025</th>\n",
|
||
" <td>56</td>\n",
|
||
" <td>33.660</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>307</th>\n",
|
||
" <td>30</td>\n",
|
||
" <td>33.330</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>840</th>\n",
|
||
" <td>21</td>\n",
|
||
" <td>31.100</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2090</th>\n",
|
||
" <td>47</td>\n",
|
||
" <td>29.545</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1587</th>\n",
|
||
" <td>48</td>\n",
|
||
" <td>32.230</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1177</th>\n",
|
||
" <td>40</td>\n",
|
||
" <td>27.400</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1259</th>\n",
|
||
" <td>52</td>\n",
|
||
" <td>23.180</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1291</th>\n",
|
||
" <td>19</td>\n",
|
||
" <td>34.900</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2040</th>\n",
|
||
" <td>59</td>\n",
|
||
" <td>35.200</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>555 rows × 8 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age bmi children sex_male region_northwest region_southeast \\\n",
|
||
"1101 53 28.600 3 1.0 0.0 0.0 \n",
|
||
"2025 56 33.660 4 1.0 0.0 1.0 \n",
|
||
"307 30 33.330 1 0.0 0.0 1.0 \n",
|
||
"840 21 31.100 0 1.0 0.0 0.0 \n",
|
||
"2090 47 29.545 1 0.0 1.0 0.0 \n",
|
||
"... ... ... ... ... ... ... \n",
|
||
"1587 48 32.230 1 0.0 0.0 1.0 \n",
|
||
"1177 40 27.400 1 0.0 0.0 0.0 \n",
|
||
"1259 52 23.180 0 0.0 0.0 0.0 \n",
|
||
"1291 19 34.900 0 1.0 0.0 0.0 \n",
|
||
"2040 59 35.200 0 0.0 0.0 1.0 \n",
|
||
"\n",
|
||
" region_southwest smoker_yes \n",
|
||
"1101 1.0 0.0 \n",
|
||
"2025 0.0 0.0 \n",
|
||
"307 0.0 0.0 \n",
|
||
"840 1.0 0.0 \n",
|
||
"2090 0.0 0.0 \n",
|
||
"... ... ... \n",
|
||
"1587 0.0 0.0 \n",
|
||
"1177 1.0 0.0 \n",
|
||
"1259 0.0 0.0 \n",
|
||
"1291 1.0 1.0 \n",
|
||
"2040 0.0 0.0 \n",
|
||
"\n",
|
||
"[555 rows x 8 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>charges</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>1101</th>\n",
|
||
" <td>11253.42100</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2025</th>\n",
|
||
" <td>12949.15540</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>307</th>\n",
|
||
" <td>4151.02870</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>840</th>\n",
|
||
" <td>1526.31200</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2090</th>\n",
|
||
" <td>8930.93455</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1587</th>\n",
|
||
" <td>8871.15170</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1177</th>\n",
|
||
" <td>6496.88600</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1259</th>\n",
|
||
" <td>10197.77220</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1291</th>\n",
|
||
" <td>34828.65400</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2040</th>\n",
|
||
" <td>12244.53100</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>555 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" charges\n",
|
||
"1101 11253.42100\n",
|
||
"2025 12949.15540\n",
|
||
"307 4151.02870\n",
|
||
"840 1526.31200\n",
|
||
"2090 8930.93455\n",
|
||
"... ...\n",
|
||
"1587 8871.15170\n",
|
||
"1177 6496.88600\n",
|
||
"1259 10197.77220\n",
|
||
"1291 34828.65400\n",
|
||
"2040 12244.53100\n",
|
||
"\n",
|
||
"[555 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from utils import split_stratified_into_train_val_test\n",
|
||
"\n",
|
||
"splits = split_stratified_into_train_val_test(\n",
|
||
"    df,\n",
|
||
"    stratify_colname=\"age\",\n",
|
||
"    target_colname=\"charges\",\n",
|
||
"    frac_train=0.80,\n",
|
||
"    frac_val=0,\n",
|
||
"    frac_test=0.20,\n",
|
||
"    random_state=random_state,\n",
|
||
")\n",
|
||
"X_train, X_val, X_test, y_train, y_val, y_test = splits\n",
|
||
"\n",
|
||
"# The feature frames are stripped of the target column before modelling.\n",
|
||
"X_train = X_train.drop(columns=[\"charges\"])\n",
|
||
"X_test = X_test.drop(columns=[\"charges\"])\n",
|
||
"\n",
|
||
"# Show each split preceded by its name.\n",
|
||
"for split_label, split_frame in (\n",
|
||
"    (\"X_train\", X_train),\n",
|
||
"    (\"y_train\", y_train),\n",
|
||
"    (\"X_test\", X_test),\n",
|
||
"    (\"y_test\", y_test),\n",
|
||
"):\n",
|
||
"    display(split_label, split_frame)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.pipeline import make_pipeline\n",
|
||
"from sklearn.preprocessing import PolynomialFeatures\n",
|
||
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
|
||
"\n",
|
||
"random_state = 9\n",
|
||
"\n",
|
||
"\n",
|
||
"def _linear_on_expanded_features(**poly_kwargs):\n",
|
||
"    \"\"\"Return a pipeline of PolynomialFeatures(**poly_kwargs) and an intercept-free LinearRegression.\"\"\"\n",
|
||
"    return make_pipeline(\n",
|
||
"        PolynomialFeatures(**poly_kwargs),\n",
|
||
"        linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
||
"    )\n",
|
||
"\n",
|
||
"\n",
|
||
"# Unfitted estimators keyed by the name used in the results tables.\n",
|
||
"estimators = {\n",
|
||
"    \"linear\": linear_model.LinearRegression(n_jobs=-1),\n",
|
||
"    \"linear_poly\": _linear_on_expanded_features(degree=2),\n",
|
||
"    \"linear_interact\": _linear_on_expanded_features(interaction_only=True),\n",
|
||
"    \"ridge\": linear_model.RidgeCV(),\n",
|
||
"    \"decision_tree\": tree.DecisionTreeRegressor(\n",
|
||
"        max_depth=7, random_state=random_state\n",
|
||
"    ),\n",
|
||
"    \"knn\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1),\n",
|
||
"    \"random_forest\": ensemble.RandomForestRegressor(\n",
|
||
"        max_depth=7, random_state=random_state, n_jobs=-1\n",
|
||
"    ),\n",
|
||
"    \"mlp\": neural_network.MLPRegressor(\n",
|
||
"        activation=\"tanh\",\n",
|
||
"        hidden_layer_sizes=(3,),\n",
|
||
"        max_iter=500,\n",
|
||
"        early_stopping=True,\n",
|
||
"        random_state=random_state,\n",
|
||
"    ),\n",
|
||
"}\n",
|
||
"\n",
|
||
"# Each entry starts with only its estimator; the training cell adds results.\n",
|
||
"models = {name: {\"model\": estimator} for name, estimator in estimators.items()}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Определение функции для стандартизации числовых значений для MLP"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from pandas import DataFrame\n",
|
||
"from sklearn import preprocessing\n",
|
||
"\n",
|
||
"\n",
|
||
"stndart_scaler = preprocessing.StandardScaler()\n",
|
||
"\n",
|
||
"\n",
|
||
"def std_q(df: DataFrame) -> DataFrame:\n",
|
||
"    \"\"\"Standardize the ``age``, ``bmi`` and ``children`` columns of ``df``.\n",
|
||
"\n",
|
||
"    Each column is passed on its own through ``stndart_scaler.fit_transform``,\n",
|
||
"    so the shared scaler is re-fitted for every column on every call and the\n",
|
||
"    statistics always come from the frame that is passed in.\n",
|
||
"\n",
|
||
"    The frame is modified in place and the same object is returned.\n",
|
||
"    \"\"\"\n",
|
||
"    for column in (\"age\", \"bmi\", \"children\"):\n",
|
||
"        # The scaler expects a 2-D input, hence the (-1, 1) reshape.\n",
|
||
"        column_2d = df[column].to_numpy().reshape(-1, 1)\n",
|
||
"        scaled = np.array(stndart_scaler.fit_transform(column_2d))\n",
|
||
"        # Flatten back to the column's original 1-D shape before storing.\n",
|
||
"        df[column] = scaled.reshape(df[column].shape)\n",
|
||
"    return df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Обучение и оценка моделей с помощью различных алгоритмов"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: linear\n",
|
||
"Model: linear_poly\n",
|
||
"Model: linear_interact\n",
|
||
"Model: ridge\n",
|
||
"Model: decision_tree\n",
|
||
"Model: knn\n",
|
||
"Model: random_forest\n",
|
||
"Model: mlp\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import math\n",
|
||
"from pandas import DataFrame\n",
|
||
"from sklearn import metrics, preprocessing\n",
|
||
"\n",
|
||
"# Numeric columns that are standardized before being fed to the MLP.\n",
|
||
"mlp_scaled_columns = [\"age\", \"bmi\", \"children\"]\n",
|
||
"\n",
|
||
"for model_name in models.keys():\n",
|
||
"    print(f\"Model: {model_name}\")\n",
|
||
"\n",
|
||
"    # Work on copies so the shared splits stay untouched between models.\n",
|
||
"    x_train: DataFrame = X_train.copy()\n",
|
||
"    x_test: DataFrame = X_test.copy()\n",
|
||
"\n",
|
||
"    if model_name == \"mlp\":\n",
|
||
"        # Fit the scaler on the training split only and reuse its statistics\n",
|
||
"        # for the test split. Scaling each split with its own mean/std would\n",
|
||
"        # put train and test features on different scales and would let\n",
|
||
"        # test-set statistics enter the preprocessing.\n",
|
||
"        mlp_scaler = preprocessing.StandardScaler().fit(x_train[mlp_scaled_columns])\n",
|
||
"        x_train[mlp_scaled_columns] = mlp_scaler.transform(x_train[mlp_scaled_columns])\n",
|
||
"        x_test[mlp_scaled_columns] = mlp_scaler.transform(x_test[mlp_scaled_columns])\n",
|
||
"\n",
|
||
"    fitted_model = models[model_name][\"model\"].fit(\n",
|
||
"        x_train.values, y_train.values.ravel()\n",
|
||
"    )\n",
|
||
"    y_train_pred = fitted_model.predict(x_train.values)\n",
|
||
"    y_test_pred = fitted_model.predict(x_test.values)\n",
|
||
"\n",
|
||
"    # Keep the fitted estimator, its predictions and the metrics together.\n",
|
||
"    models[model_name][\"fitted\"] = fitted_model\n",
|
||
"    models[model_name][\"train_preds\"] = y_train_pred\n",
|
||
"    models[model_name][\"preds\"] = y_test_pred\n",
|
||
"    models[model_name][\"RMSE_train\"] = math.sqrt(\n",
|
||
"        metrics.mean_squared_error(y_train, y_train_pred)\n",
|
||
"    )\n",
|
||
"    models[model_name][\"RMSE_test\"] = math.sqrt(\n",
|
||
"        metrics.mean_squared_error(y_test, y_test_pred)\n",
|
||
"    )\n",
|
||
"    # NOTE: this stores the square root of the mean absolute error,\n",
|
||
"    # not the mean absolute error itself.\n",
|
||
"    models[model_name][\"RMAE_test\"] = math.sqrt(\n",
|
||
"        metrics.mean_absolute_error(y_test, y_test_pred)\n",
|
||
"    )\n",
|
||
"    models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Вывод результатов оценки"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_3759d_row0_col0, #T_3759d_row0_col1 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row0_col2, #T_3759d_row7_col3 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row0_col3, #T_3759d_row7_col2 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row1_col0 {\n",
|
||
" background-color: #25838e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row1_col1 {\n",
|
||
" background-color: #26828e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row1_col2 {\n",
|
||
" background-color: #5102a3;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row1_col3 {\n",
|
||
" background-color: #d9586a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row2_col0 {\n",
|
||
" background-color: #228b8d;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row2_col1, #T_3759d_row3_col1 {\n",
|
||
" background-color: #24878e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row2_col2 {\n",
|
||
" background-color: #6300a7;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row2_col3, #T_3759d_row3_col3 {\n",
|
||
" background-color: #d7566c;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row3_col0 {\n",
|
||
" background-color: #228c8d;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row3_col2 {\n",
|
||
" background-color: #6400a7;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row4_col0, #T_3759d_row5_col0 {\n",
|
||
" background-color: #1f948c;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row4_col1, #T_3759d_row5_col1 {\n",
|
||
" background-color: #21918c;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row4_col2, #T_3759d_row5_col2 {\n",
|
||
" background-color: #7e03a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row4_col3, #T_3759d_row5_col3 {\n",
|
||
" background-color: #d35171;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row6_col0 {\n",
|
||
" background-color: #20a486;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row6_col1 {\n",
|
||
" background-color: #24aa83;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row6_col2 {\n",
|
||
" background-color: #a01a9c;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row6_col3 {\n",
|
||
" background-color: #c13b82;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_3759d_row7_col0, #T_3759d_row7_col1 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_3759d\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_3759d_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
|
||
" <th id=\"T_3759d_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
|
||
" <th id=\"T_3759d_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
|
||
" <th id=\"T_3759d_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_3759d_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
|
||
" <td id=\"T_3759d_row0_col0\" class=\"data row0 col0\" >3221.469707</td>\n",
|
||
" <td id=\"T_3759d_row0_col1\" class=\"data row0 col1\" >3953.661053</td>\n",
|
||
" <td id=\"T_3759d_row0_col2\" class=\"data row0 col2\" >45.741609</td>\n",
|
||
" <td id=\"T_3759d_row0_col3\" class=\"data row0 col3\" >0.901103</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_3759d_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
|
||
" <td id=\"T_3759d_row1_col0\" class=\"data row1 col0\" >3643.279193</td>\n",
|
||
" <td id=\"T_3759d_row1_col1\" class=\"data row1 col1\" >4288.040726</td>\n",
|
||
" <td id=\"T_3759d_row1_col2\" class=\"data row1 col2\" >47.359073</td>\n",
|
||
" <td id=\"T_3759d_row1_col3\" class=\"data row1 col3\" >0.883668</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_3759d_level0_row2\" class=\"row_heading level0 row2\" >linear_poly</th>\n",
|
||
" <td id=\"T_3759d_row2_col0\" class=\"data row2 col0\" >4731.024654</td>\n",
|
||
" <td id=\"T_3759d_row2_col1\" class=\"data row2 col1\" >4868.817371</td>\n",
|
||
" <td id=\"T_3759d_row2_col2\" class=\"data row2 col2\" >54.257745</td>\n",
|
||
" <td id=\"T_3759d_row2_col3\" class=\"data row2 col3\" >0.850021</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_3759d_level0_row3\" class=\"row_heading level0 row3\" >linear_interact</th>\n",
|
||
" <td id=\"T_3759d_row3_col0\" class=\"data row3 col0\" >4776.393716</td>\n",
|
||
" <td id=\"T_3759d_row3_col1\" class=\"data row3 col1\" >4938.699556</td>\n",
|
||
" <td id=\"T_3759d_row3_col2\" class=\"data row3 col2\" >54.641209</td>\n",
|
||
" <td id=\"T_3759d_row3_col3\" class=\"data row3 col3\" >0.845685</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_3759d_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
|
||
" <td id=\"T_3759d_row4_col0\" class=\"data row4 col0\" >6028.427617</td>\n",
|
||
" <td id=\"T_3759d_row4_col1\" class=\"data row4 col1\" >6216.544081</td>\n",
|
||
" <td id=\"T_3759d_row4_col2\" class=\"data row4 col2\" >65.584948</td>\n",
|
||
" <td id=\"T_3759d_row4_col3\" class=\"data row4 col3\" >0.755499</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_3759d_level0_row5\" class=\"row_heading level0 row5\" >linear</th>\n",
|
||
" <td id=\"T_3759d_row5_col0\" class=\"data row5 col0\" >6028.426993</td>\n",
|
||
" <td id=\"T_3759d_row5_col1\" class=\"data row5 col1\" >6216.588829</td>\n",
|
||
" <td id=\"T_3759d_row5_col2\" class=\"data row5 col2\" >65.580879</td>\n",
|
||
" <td id=\"T_3759d_row5_col3\" class=\"data row5 col3\" >0.755496</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_3759d_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
|
||
" <td id=\"T_3759d_row6_col0\" class=\"data row6 col0\" >8230.959070</td>\n",
|
||
" <td id=\"T_3759d_row6_col1\" class=\"data row6 col1\" >9715.102581</td>\n",
|
||
" <td id=\"T_3759d_row6_col2\" class=\"data row6 col2\" >81.129201</td>\n",
|
||
" <td id=\"T_3759d_row6_col3\" class=\"data row6 col3\" >0.402859</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_3759d_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
|
||
" <td id=\"T_3759d_row7_col0\" class=\"data row7 col0\" >17848.198895</td>\n",
|
||
" <td id=\"T_3759d_row7_col1\" class=\"data row7 col1\" >18518.275054</td>\n",
|
||
" <td id=\"T_3759d_row7_col2\" class=\"data row7 col2\" >116.605174</td>\n",
|
||
" <td id=\"T_3759d_row7_col3\" class=\"data row7 col3\" >-1.169619</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x203e1a15460>"
|
||
]
|
||
},
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
|
||
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
|
||
"]\n",
|
||
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient( # type: ignore\n",
|
||
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
|
||
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Получение лучшей модели"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'random_forest'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name) # type: ignore\n",
|
||
"\n",
|
||
"display(best_model)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Вывод для обучающей выборки"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>children</th>\n",
|
||
" <th>sex_male</th>\n",
|
||
" <th>region_northwest</th>\n",
|
||
" <th>region_southeast</th>\n",
|
||
" <th>region_southwest</th>\n",
|
||
" <th>smoker_yes</th>\n",
|
||
" <th>charges</th>\n",
|
||
" <th>ChargesPred</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2146</th>\n",
|
||
" <td>22</td>\n",
|
||
" <td>34.580</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>3925.75820</td>\n",
|
||
" <td>4868.448184</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>472</th>\n",
|
||
" <td>19</td>\n",
|
||
" <td>29.800</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1744.46500</td>\n",
|
||
" <td>3011.805771</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>801</th>\n",
|
||
" <td>64</td>\n",
|
||
" <td>35.970</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>14313.84630</td>\n",
|
||
" <td>13841.766282</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>84</th>\n",
|
||
" <td>37</td>\n",
|
||
" <td>34.800</td>\n",
|
||
" <td>2</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>39836.51900</td>\n",
|
||
" <td>39427.673528</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2028</th>\n",
|
||
" <td>61</td>\n",
|
||
" <td>33.915</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>13143.86485</td>\n",
|
||
" <td>13575.291528</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age bmi children sex_male region_northwest region_southeast \\\n",
|
||
"2146 22 34.580 2 0.0 0.0 0.0 \n",
|
||
"472 19 29.800 0 0.0 0.0 0.0 \n",
|
||
"801 64 35.970 0 0.0 0.0 1.0 \n",
|
||
"84 37 34.800 2 0.0 0.0 0.0 \n",
|
||
"2028 61 33.915 0 1.0 0.0 0.0 \n",
|
||
"\n",
|
||
" region_southwest smoker_yes charges ChargesPred \n",
|
||
"2146 0.0 0.0 3925.75820 4868.448184 \n",
|
||
"472 1.0 0.0 1744.46500 3011.805771 \n",
|
||
"801 0.0 0.0 14313.84630 13841.766282 \n",
|
||
"84 1.0 1.0 39836.51900 39427.673528 \n",
|
||
"2028 0.0 0.0 13143.86485 13575.291528 "
|
||
]
|
||
},
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"pd.concat(\n",
|
||
" [\n",
|
||
" X_train,\n",
|
||
" y_train,\n",
|
||
" pd.Series(\n",
|
||
" models[best_model][\"train_preds\"],\n",
|
||
" index=y_train.index,\n",
|
||
" name=\"ChargesPred\",\n",
|
||
" ),\n",
|
||
" ],\n",
|
||
" axis=1,\n",
|
||
").head(5)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Вывод для тестовой выборки"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>age</th>\n",
|
||
" <th>bmi</th>\n",
|
||
" <th>children</th>\n",
|
||
" <th>sex_male</th>\n",
|
||
" <th>region_northwest</th>\n",
|
||
" <th>region_southeast</th>\n",
|
||
" <th>region_southwest</th>\n",
|
||
" <th>smoker_yes</th>\n",
|
||
" <th>charges</th>\n",
|
||
" <th>ChargesPred</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>1101</th>\n",
|
||
" <td>53</td>\n",
|
||
" <td>28.600</td>\n",
|
||
" <td>3</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>11253.42100</td>\n",
|
||
" <td>12139.772544</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2025</th>\n",
|
||
" <td>56</td>\n",
|
||
" <td>33.660</td>\n",
|
||
" <td>4</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>12949.15540</td>\n",
|
||
" <td>14977.306757</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>307</th>\n",
|
||
" <td>30</td>\n",
|
||
" <td>33.330</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>4151.02870</td>\n",
|
||
" <td>5778.492115</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>840</th>\n",
|
||
" <td>21</td>\n",
|
||
" <td>31.100</td>\n",
|
||
" <td>0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1526.31200</td>\n",
|
||
" <td>3324.843009</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2090</th>\n",
|
||
" <td>47</td>\n",
|
||
" <td>29.545</td>\n",
|
||
" <td>1</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>8930.93455</td>\n",
|
||
" <td>11318.629065</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" age bmi children sex_male region_northwest region_southeast \\\n",
|
||
"1101 53 28.600 3 1.0 0.0 0.0 \n",
|
||
"2025 56 33.660 4 1.0 0.0 1.0 \n",
|
||
"307 30 33.330 1 0.0 0.0 1.0 \n",
|
||
"840 21 31.100 0 1.0 0.0 0.0 \n",
|
||
"2090 47 29.545 1 0.0 1.0 0.0 \n",
|
||
"\n",
|
||
" region_southwest smoker_yes charges ChargesPred \n",
|
||
"1101 1.0 0.0 11253.42100 12139.772544 \n",
|
||
"2025 0.0 0.0 12949.15540 14977.306757 \n",
|
||
"307 0.0 0.0 4151.02870 5778.492115 \n",
|
||
"840 1.0 0.0 1526.31200 3324.843009 \n",
|
||
"2090 0.0 0.0 8930.93455 11318.629065 "
|
||
]
|
||
},
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"pd.concat(\n",
|
||
" [\n",
|
||
" X_test,\n",
|
||
" y_test,\n",
|
||
" pd.Series(\n",
|
||
" models[best_model][\"preds\"],\n",
|
||
" index=y_test.index,\n",
|
||
" name=\"ChargesPred\",\n",
|
||
" ),\n",
|
||
" ],\n",
|
||
" axis=1,\n",
|
||
").head(5)"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.4"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|