MII_Salin_Oleg_PIbd-33/lec4_reg.ipynb

1594 lines
51 KiB
Plaintext
Raw Normal View History

2024-11-08 22:37:34 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Загрузка данных"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>charges</th>\n",
" <th>sex_male</th>\n",
" <th>region_northwest</th>\n",
" <th>region_southeast</th>\n",
" <th>region_southwest</th>\n",
" <th>smoker_yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>19</td>\n",
" <td>27.900</td>\n",
" <td>0</td>\n",
" <td>16884.92400</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>18</td>\n",
" <td>33.770</td>\n",
" <td>1</td>\n",
" <td>1725.55230</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>28</td>\n",
" <td>33.000</td>\n",
" <td>3</td>\n",
" <td>4449.46200</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>33</td>\n",
" <td>22.705</td>\n",
" <td>0</td>\n",
" <td>21984.47061</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>32</td>\n",
" <td>28.880</td>\n",
" <td>0</td>\n",
" <td>3866.85520</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2767</th>\n",
" <td>47</td>\n",
" <td>45.320</td>\n",
" <td>1</td>\n",
" <td>8569.86180</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2768</th>\n",
" <td>21</td>\n",
" <td>34.600</td>\n",
" <td>0</td>\n",
" <td>2020.17700</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2769</th>\n",
" <td>19</td>\n",
" <td>26.030</td>\n",
" <td>1</td>\n",
" <td>16450.89470</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2770</th>\n",
" <td>23</td>\n",
" <td>18.715</td>\n",
" <td>0</td>\n",
" <td>21595.38229</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2771</th>\n",
" <td>54</td>\n",
" <td>31.600</td>\n",
" <td>0</td>\n",
" <td>9850.43200</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2772 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" age bmi children charges sex_male region_northwest \\\n",
"0 19 27.900 0 16884.92400 0.0 0.0 \n",
"1 18 33.770 1 1725.55230 1.0 0.0 \n",
"2 28 33.000 3 4449.46200 1.0 0.0 \n",
"3 33 22.705 0 21984.47061 1.0 1.0 \n",
"4 32 28.880 0 3866.85520 1.0 1.0 \n",
"... ... ... ... ... ... ... \n",
"2767 47 45.320 1 8569.86180 0.0 0.0 \n",
"2768 21 34.600 0 2020.17700 0.0 0.0 \n",
"2769 19 26.030 1 16450.89470 1.0 1.0 \n",
"2770 23 18.715 0 21595.38229 1.0 1.0 \n",
"2771 54 31.600 0 9850.43200 1.0 0.0 \n",
"\n",
" region_southeast region_southwest smoker_yes \n",
"0 0.0 1.0 1.0 \n",
"1 1.0 0.0 0.0 \n",
"2 1.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"... ... ... ... \n",
"2767 1.0 0.0 0.0 \n",
"2768 0.0 1.0 0.0 \n",
"2769 0.0 0.0 1.0 \n",
"2770 0.0 0.0 0.0 \n",
"2771 0.0 1.0 0.0 \n",
"\n",
"[2772 rows x 9 columns]"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np  # type: ignore\n",
"import pandas as pd\n",
"from sklearn import set_config\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"# Make sklearn transformers return pandas DataFrames instead of NumPy arrays.\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"random_state = 9\n",
"\n",
"categorical_features = [\"sex\", \"region\", \"smoker\"]\n",
"\n",
"df = pd.read_csv(\"data/Medical_insurance.csv\", index_col=False)\n",
"\n",
"# One-hot encode the categorical features; drop=\"first\" drops the first category of each.\n",
"encoder = OneHotEncoder(sparse_output=False, drop=\"first\")\n",
"encoded_values = encoder.fit_transform(df[categorical_features])\n",
"encoded_columns = encoder.get_feature_names_out(categorical_features)\n",
"encoded_values_df = pd.DataFrame(encoded_values, columns=encoded_columns)\n",
"\n",
"# Replace the raw categorical columns with their dummy columns.\n",
"df = pd.concat([df, encoded_values_df], axis=1).drop(\n",
"    columns=[\"sex\", \"smoker\", \"region\"]\n",
")\n",
"\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование выборок"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>sex_male</th>\n",
" <th>region_northwest</th>\n",
" <th>region_southeast</th>\n",
" <th>region_southwest</th>\n",
" <th>smoker_yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2146</th>\n",
" <td>22</td>\n",
" <td>34.580</td>\n",
" <td>2</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>472</th>\n",
" <td>19</td>\n",
" <td>29.800</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>801</th>\n",
" <td>64</td>\n",
" <td>35.970</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84</th>\n",
" <td>37</td>\n",
" <td>34.800</td>\n",
" <td>2</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2028</th>\n",
" <td>61</td>\n",
" <td>33.915</td>\n",
" <td>0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>979</th>\n",
" <td>36</td>\n",
" <td>29.920</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1475</th>\n",
" <td>55</td>\n",
" <td>26.980</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2547</th>\n",
" <td>34</td>\n",
" <td>42.130</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2553</th>\n",
" <td>29</td>\n",
" <td>24.600</td>\n",
" <td>2</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1974</th>\n",
" <td>61</td>\n",
" <td>35.910</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2217 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" age bmi children sex_male region_northwest region_southeast \\\n",
"2146 22 34.580 2 0.0 0.0 0.0 \n",
"472 19 29.800 0 0.0 0.0 0.0 \n",
"801 64 35.970 0 0.0 0.0 1.0 \n",
"84 37 34.800 2 0.0 0.0 0.0 \n",
"2028 61 33.915 0 1.0 0.0 0.0 \n",
"... ... ... ... ... ... ... \n",
"979 36 29.920 0 0.0 0.0 1.0 \n",
"1475 55 26.980 0 0.0 1.0 0.0 \n",
"2547 34 42.130 2 1.0 0.0 1.0 \n",
"2553 29 24.600 2 0.0 0.0 0.0 \n",
"1974 61 35.910 0 0.0 0.0 0.0 \n",
"\n",
" region_southwest smoker_yes \n",
"2146 0.0 0.0 \n",
"472 1.0 0.0 \n",
"801 0.0 0.0 \n",
"84 1.0 1.0 \n",
"2028 0.0 0.0 \n",
"... ... ... \n",
"979 0.0 0.0 \n",
"1475 0.0 0.0 \n",
"2547 0.0 0.0 \n",
"2553 1.0 0.0 \n",
"1974 0.0 0.0 \n",
"\n",
"[2217 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>charges</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2146</th>\n",
" <td>3925.75820</td>\n",
" </tr>\n",
" <tr>\n",
" <th>472</th>\n",
" <td>1744.46500</td>\n",
" </tr>\n",
" <tr>\n",
" <th>801</th>\n",
" <td>14313.84630</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84</th>\n",
" <td>39836.51900</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2028</th>\n",
" <td>13143.86485</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>979</th>\n",
" <td>4889.03680</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1475</th>\n",
" <td>11082.57720</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2547</th>\n",
" <td>5124.18870</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2553</th>\n",
" <td>4529.47700</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1974</th>\n",
" <td>13635.63790</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2217 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" charges\n",
"2146 3925.75820\n",
"472 1744.46500\n",
"801 14313.84630\n",
"84 39836.51900\n",
"2028 13143.86485\n",
"... ...\n",
"979 4889.03680\n",
"1475 11082.57720\n",
"2547 5124.18870\n",
"2553 4529.47700\n",
"1974 13635.63790\n",
"\n",
"[2217 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>sex_male</th>\n",
" <th>region_northwest</th>\n",
" <th>region_southeast</th>\n",
" <th>region_southwest</th>\n",
" <th>smoker_yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1101</th>\n",
" <td>53</td>\n",
" <td>28.600</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2025</th>\n",
" <td>56</td>\n",
" <td>33.660</td>\n",
" <td>4</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>307</th>\n",
" <td>30</td>\n",
" <td>33.330</td>\n",
" <td>1</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>840</th>\n",
" <td>21</td>\n",
" <td>31.100</td>\n",
" <td>0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2090</th>\n",
" <td>47</td>\n",
" <td>29.545</td>\n",
" <td>1</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1587</th>\n",
" <td>48</td>\n",
" <td>32.230</td>\n",
" <td>1</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1177</th>\n",
" <td>40</td>\n",
" <td>27.400</td>\n",
" <td>1</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1259</th>\n",
" <td>52</td>\n",
" <td>23.180</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1291</th>\n",
" <td>19</td>\n",
" <td>34.900</td>\n",
" <td>0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2040</th>\n",
" <td>59</td>\n",
" <td>35.200</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>555 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" age bmi children sex_male region_northwest region_southeast \\\n",
"1101 53 28.600 3 1.0 0.0 0.0 \n",
"2025 56 33.660 4 1.0 0.0 1.0 \n",
"307 30 33.330 1 0.0 0.0 1.0 \n",
"840 21 31.100 0 1.0 0.0 0.0 \n",
"2090 47 29.545 1 0.0 1.0 0.0 \n",
"... ... ... ... ... ... ... \n",
"1587 48 32.230 1 0.0 0.0 1.0 \n",
"1177 40 27.400 1 0.0 0.0 0.0 \n",
"1259 52 23.180 0 0.0 0.0 0.0 \n",
"1291 19 34.900 0 1.0 0.0 0.0 \n",
"2040 59 35.200 0 0.0 0.0 1.0 \n",
"\n",
" region_southwest smoker_yes \n",
"1101 1.0 0.0 \n",
"2025 0.0 0.0 \n",
"307 0.0 0.0 \n",
"840 1.0 0.0 \n",
"2090 0.0 0.0 \n",
"... ... ... \n",
"1587 0.0 0.0 \n",
"1177 1.0 0.0 \n",
"1259 0.0 0.0 \n",
"1291 1.0 1.0 \n",
"2040 0.0 0.0 \n",
"\n",
"[555 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>charges</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1101</th>\n",
" <td>11253.42100</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2025</th>\n",
" <td>12949.15540</td>\n",
" </tr>\n",
" <tr>\n",
" <th>307</th>\n",
" <td>4151.02870</td>\n",
" </tr>\n",
" <tr>\n",
" <th>840</th>\n",
" <td>1526.31200</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2090</th>\n",
" <td>8930.93455</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1587</th>\n",
" <td>8871.15170</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1177</th>\n",
" <td>6496.88600</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1259</th>\n",
" <td>10197.77220</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1291</th>\n",
" <td>34828.65400</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2040</th>\n",
" <td>12244.53100</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>555 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" charges\n",
"1101 11253.42100\n",
"2025 12949.15540\n",
"307 4151.02870\n",
"840 1526.31200\n",
"2090 8930.93455\n",
"... ...\n",
"1587 8871.15170\n",
"1177 6496.88600\n",
"1259 10197.77220\n",
"1291 34828.65400\n",
"2040 12244.53100\n",
"\n",
"[555 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from utils import split_stratified_into_train_val_test\n",
"\n",
"# 80/20 train/test split stratified by age; no validation share is requested.\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
"    df,\n",
"    stratify_colname=\"age\",\n",
"    target_colname=\"charges\",\n",
"    frac_train=0.80,\n",
"    frac_val=0,\n",
"    frac_test=0.20,\n",
"    random_state=random_state,\n",
")\n",
"\n",
"# The returned feature frames still contain the target column, so remove it.\n",
"X_train, X_test = (\n",
"    features.drop(columns=[\"charges\"]) for features in (X_train, X_test)\n",
")\n",
"\n",
"for title, frame in (\n",
"    (\"X_train\", X_train),\n",
"    (\"y_train\", y_train),\n",
"    (\"X_test\", X_test),\n",
"    (\"y_test\", y_test),\n",
"):\n",
"    display(title, frame)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, neighbors, neural_network, tree\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"# `random_state` is defined once, in the data-loading cell, and is deliberately\n",
"# not redefined here so the split and the models always share the same seed.\n",
"models = {\n",
"    \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
"    # PolynomialFeatures adds a bias column by default, hence fit_intercept=False.\n",
"    \"linear_poly\": {\n",
"        \"model\": make_pipeline(\n",
"            PolynomialFeatures(degree=2),\n",
"            linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
"        )\n",
"    },\n",
"    \"linear_interact\": {\n",
"        \"model\": make_pipeline(\n",
"            PolynomialFeatures(interaction_only=True),\n",
"            linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
"        )\n",
"    },\n",
"    \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
"    \"decision_tree\": {\n",
"        \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
"    },\n",
"    \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
"    \"random_forest\": {\n",
"        \"model\": ensemble.RandomForestRegressor(\n",
"            max_depth=7, random_state=random_state, n_jobs=-1\n",
"        )\n",
"    },\n",
"    # NOTE(review): only the input features are standardized for the MLP; the\n",
"    # target `charges` is fed unscaled. If the MLP keeps underperforming,\n",
"    # consider wrapping it in sklearn.compose.TransformedTargetRegressor.\n",
"    \"mlp\": {\n",
"        \"model\": neural_network.MLPRegressor(\n",
"            activation=\"tanh\",\n",
"            hidden_layer_sizes=(3,),\n",
"            max_iter=500,\n",
"            early_stopping=True,\n",
"            random_state=random_state,\n",
"        )\n",
"    },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Определение функции для стандартизации числовых значений для MLP"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"from pandas import DataFrame\n",
"from sklearn import preprocessing\n",
"\n",
"\n",
"def std_q(\n",
"    df: DataFrame,\n",
"    reference: DataFrame | None = None,\n",
"    columns: tuple[str, ...] = (\"age\", \"bmi\", \"children\"),\n",
") -> DataFrame:\n",
"    \"\"\"Standardize the numeric ``columns`` of ``df`` in place and return ``df``.\n",
"\n",
"    The mean and standard deviation are estimated on ``reference``; pass the\n",
"    training features when scaling a test set so that every split is scaled\n",
"    with the same statistics. When ``reference`` is None, ``df`` itself is\n",
"    used, which is only appropriate for the training set.\n",
"    \"\"\"\n",
"    selected = list(columns)\n",
"    source = df if reference is None else reference\n",
"    scaler = preprocessing.StandardScaler().fit(source[selected])\n",
"    # np.asarray keeps the assignment positional even when sklearn is\n",
"    # configured to return DataFrames from transform().\n",
"    df[selected] = np.asarray(scaler.transform(df[selected]))\n",
"    return df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение и оценка моделей с помощью различных алгоритмов"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import math\n",
"\n",
"from pandas import DataFrame\n",
"from sklearn import metrics\n",
"\n",
"for model_name in models.keys():\n",
"    print(f\"Model: {model_name}\")\n",
"\n",
"    # Work on copies so that scaling for the MLP never modifies X_train / X_test.\n",
"    x_train: DataFrame = X_train.copy()\n",
"    x_test: DataFrame = X_test.copy()\n",
"\n",
"    if model_name == \"mlp\":\n",
"        # Scale both splits with statistics estimated on the training set only.\n",
"        # Fitting a separate scaler on the test set leaks test information and\n",
"        # maps test features onto a different scale than the one used in training.\n",
"        x_test = std_q(x_test, reference=X_train)\n",
"        x_train = std_q(x_train)\n",
"\n",
"    fitted_model = models[model_name][\"model\"].fit(\n",
"        x_train.values, y_train.values.ravel()\n",
"    )\n",
"    y_train_pred = fitted_model.predict(x_train.values)\n",
"    y_test_pred = fitted_model.predict(x_test.values)\n",
"    models[model_name][\"fitted\"] = fitted_model\n",
"    models[model_name][\"train_preds\"] = y_train_pred\n",
"    models[model_name][\"preds\"] = y_test_pred\n",
"    models[model_name][\"RMSE_train\"] = math.sqrt(\n",
"        metrics.mean_squared_error(y_train, y_train_pred)\n",
"    )\n",
"    models[model_name][\"RMSE_test\"] = math.sqrt(\n",
"        metrics.mean_squared_error(y_test, y_test_pred)\n",
"    )\n",
"    # RMAE_test is the square root of the mean absolute error.\n",
"    models[model_name][\"RMAE_test\"] = math.sqrt(\n",
"        metrics.mean_absolute_error(y_test, y_test_pred)\n",
"    )\n",
"    models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод результатов оценки"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_995df_row0_col0, #T_995df_row0_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row0_col2, #T_995df_row7_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row0_col3, #T_995df_row7_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row1_col0 {\n",
" background-color: #25838e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row1_col1 {\n",
" background-color: #26828e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row1_col2 {\n",
" background-color: #5102a3;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row1_col3 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row2_col0 {\n",
" background-color: #228b8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row2_col1, #T_995df_row3_col1 {\n",
" background-color: #24878e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row2_col2 {\n",
" background-color: #6300a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row2_col3, #T_995df_row3_col3 {\n",
" background-color: #d7566c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row3_col0 {\n",
" background-color: #228c8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row3_col2 {\n",
" background-color: #6400a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row4_col0, #T_995df_row5_col0 {\n",
" background-color: #1f948c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row4_col1, #T_995df_row5_col1 {\n",
" background-color: #21918c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row4_col2, #T_995df_row5_col2 {\n",
" background-color: #7e03a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row4_col3, #T_995df_row5_col3 {\n",
" background-color: #d35171;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row6_col0 {\n",
" background-color: #20a486;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row6_col1 {\n",
" background-color: #24aa83;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row6_col2 {\n",
" background-color: #a01a9c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row6_col3 {\n",
" background-color: #c13b82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row7_col0, #T_995df_row7_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_995df\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_995df_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_995df_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_995df_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_995df_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_995df_row0_col0\" class=\"data row0 col0\" >3221.469707</td>\n",
" <td id=\"T_995df_row0_col1\" class=\"data row0 col1\" >3953.661053</td>\n",
" <td id=\"T_995df_row0_col2\" class=\"data row0 col2\" >45.741609</td>\n",
" <td id=\"T_995df_row0_col3\" class=\"data row0 col3\" >0.901103</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_995df_row1_col0\" class=\"data row1 col0\" >3643.279193</td>\n",
" <td id=\"T_995df_row1_col1\" class=\"data row1 col1\" >4288.040726</td>\n",
" <td id=\"T_995df_row1_col2\" class=\"data row1 col2\" >47.359073</td>\n",
" <td id=\"T_995df_row1_col3\" class=\"data row1 col3\" >0.883668</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row2\" class=\"row_heading level0 row2\" >linear_poly</th>\n",
" <td id=\"T_995df_row2_col0\" class=\"data row2 col0\" >4731.024654</td>\n",
" <td id=\"T_995df_row2_col1\" class=\"data row2 col1\" >4868.817371</td>\n",
" <td id=\"T_995df_row2_col2\" class=\"data row2 col2\" >54.257745</td>\n",
" <td id=\"T_995df_row2_col3\" class=\"data row2 col3\" >0.850021</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row3\" class=\"row_heading level0 row3\" >linear_interact</th>\n",
" <td id=\"T_995df_row3_col0\" class=\"data row3 col0\" >4776.393716</td>\n",
" <td id=\"T_995df_row3_col1\" class=\"data row3 col1\" >4938.699556</td>\n",
" <td id=\"T_995df_row3_col2\" class=\"data row3 col2\" >54.641209</td>\n",
" <td id=\"T_995df_row3_col3\" class=\"data row3 col3\" >0.845685</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
" <td id=\"T_995df_row4_col0\" class=\"data row4 col0\" >6028.427617</td>\n",
" <td id=\"T_995df_row4_col1\" class=\"data row4 col1\" >6216.544081</td>\n",
" <td id=\"T_995df_row4_col2\" class=\"data row4 col2\" >65.584948</td>\n",
" <td id=\"T_995df_row4_col3\" class=\"data row4 col3\" >0.755499</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row5\" class=\"row_heading level0 row5\" >linear</th>\n",
" <td id=\"T_995df_row5_col0\" class=\"data row5 col0\" >6028.426993</td>\n",
" <td id=\"T_995df_row5_col1\" class=\"data row5 col1\" >6216.588829</td>\n",
" <td id=\"T_995df_row5_col2\" class=\"data row5 col2\" >65.580879</td>\n",
" <td id=\"T_995df_row5_col3\" class=\"data row5 col3\" >0.755496</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_995df_row6_col0\" class=\"data row6 col0\" >8230.959070</td>\n",
" <td id=\"T_995df_row6_col1\" class=\"data row6 col1\" >9715.102581</td>\n",
" <td id=\"T_995df_row6_col2\" class=\"data row6 col2\" >81.129201</td>\n",
" <td id=\"T_995df_row6_col3\" class=\"data row6 col3\" >0.402859</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_995df_row7_col0\" class=\"data row7 col0\" >17848.198895</td>\n",
" <td id=\"T_995df_row7_col1\" class=\"data row7 col1\" >18518.275054</td>\n",
" <td id=\"T_995df_row7_col2\" class=\"data row7 col2\" >116.605174</td>\n",
" <td id=\"T_995df_row7_col3\" class=\"data row7 col3\" >-1.169619</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1d2bcd71160>"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"metric_columns = [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"reg_metrics = pd.DataFrame.from_dict(models, orient=\"index\")[metric_columns]\n",
"\n",
"# Rank the models by test RMSE and colour-code the metric columns.\n",
"(\n",
"    reg_metrics.sort_values(by=\"RMSE_test\")\n",
"    .style.background_gradient(\n",
"        cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"]\n",
"    )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Получение лучшей модели"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'random_forest'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Sorting ascending by test RMSE puts the best model in the first row.\n",
"ranked_by_rmse = reg_metrics.sort_values(by=\"RMSE_test\")\n",
"best_model = str(ranked_by_rmse.index[0])\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод для обучающей выборки"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>sex_male</th>\n",
" <th>region_northwest</th>\n",
" <th>region_southeast</th>\n",
" <th>region_southwest</th>\n",
" <th>smoker_yes</th>\n",
" <th>charges</th>\n",
" <th>ChargesPred</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2146</th>\n",
" <td>22</td>\n",
" <td>34.580</td>\n",
" <td>2</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3925.75820</td>\n",
" <td>4868.448184</td>\n",
" </tr>\n",
" <tr>\n",
" <th>472</th>\n",
" <td>19</td>\n",
" <td>29.800</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1744.46500</td>\n",
" <td>3011.805771</td>\n",
" </tr>\n",
" <tr>\n",
" <th>801</th>\n",
" <td>64</td>\n",
" <td>35.970</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>14313.84630</td>\n",
" <td>13841.766282</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84</th>\n",
" <td>37</td>\n",
" <td>34.800</td>\n",
" <td>2</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>39836.51900</td>\n",
" <td>39427.673528</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2028</th>\n",
" <td>61</td>\n",
" <td>33.915</td>\n",
" <td>0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>13143.86485</td>\n",
" <td>13575.291528</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age bmi children sex_male region_northwest region_southeast \\\n",
"2146 22 34.580 2 0.0 0.0 0.0 \n",
"472 19 29.800 0 0.0 0.0 0.0 \n",
"801 64 35.970 0 0.0 0.0 1.0 \n",
"84 37 34.800 2 0.0 0.0 0.0 \n",
"2028 61 33.915 0 1.0 0.0 0.0 \n",
"\n",
" region_southwest smoker_yes charges ChargesPred \n",
"2146 0.0 0.0 3925.75820 4868.448184 \n",
"472 1.0 0.0 1744.46500 3011.805771 \n",
"801 0.0 0.0 14313.84630 13841.766282 \n",
"84 1.0 1.0 39836.51900 39427.673528 \n",
"2028 0.0 0.0 13143.86485 13575.291528 "
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the features, the true target and the best model's predictions side by side.\n",
"train_predictions = pd.Series(\n",
"    models[best_model][\"train_preds\"], index=y_train.index, name=\"ChargesPred\"\n",
")\n",
"\n",
"pd.concat([X_train, y_train, train_predictions], axis=1).head(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод для тестовой выборки"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>bmi</th>\n",
" <th>children</th>\n",
" <th>sex_male</th>\n",
" <th>region_northwest</th>\n",
" <th>region_southeast</th>\n",
" <th>region_southwest</th>\n",
" <th>smoker_yes</th>\n",
" <th>charges</th>\n",
" <th>ChargesPred</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1101</th>\n",
" <td>53</td>\n",
" <td>28.600</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>11253.42100</td>\n",
" <td>12139.772544</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2025</th>\n",
" <td>56</td>\n",
" <td>33.660</td>\n",
" <td>4</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>12949.15540</td>\n",
" <td>14977.306757</td>\n",
" </tr>\n",
" <tr>\n",
" <th>307</th>\n",
" <td>30</td>\n",
" <td>33.330</td>\n",
" <td>1</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>4151.02870</td>\n",
" <td>5778.492115</td>\n",
" </tr>\n",
" <tr>\n",
" <th>840</th>\n",
" <td>21</td>\n",
" <td>31.100</td>\n",
" <td>0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1526.31200</td>\n",
" <td>3324.843009</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2090</th>\n",
" <td>47</td>\n",
" <td>29.545</td>\n",
" <td>1</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>8930.93455</td>\n",
" <td>11318.629065</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age bmi children sex_male region_northwest region_southeast \\\n",
"1101 53 28.600 3 1.0 0.0 0.0 \n",
"2025 56 33.660 4 1.0 0.0 1.0 \n",
"307 30 33.330 1 0.0 0.0 1.0 \n",
"840 21 31.100 0 1.0 0.0 0.0 \n",
"2090 47 29.545 1 0.0 1.0 0.0 \n",
"\n",
" region_southwest smoker_yes charges ChargesPred \n",
"1101 1.0 0.0 11253.42100 12139.772544 \n",
"2025 0.0 0.0 12949.15540 14977.306757 \n",
"307 0.0 0.0 4151.02870 5778.492115 \n",
"840 1.0 0.0 1526.31200 3324.843009 \n",
"2090 0.0 0.0 8930.93455 11318.629065 "
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the features, the true target and the best model's predictions side by side.\n",
"test_predictions = pd.Series(\n",
"    models[best_model][\"preds\"], index=y_test.index, name=\"ChargesPred\"\n",
")\n",
"\n",
"pd.concat([X_test, y_test, test_predictions], axis=1).head(5)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}