{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "#### Загрузка данных" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agebmichildrenchargessex_maleregion_northwestregion_southeastregion_southwestsmoker_yes
01927.900016884.924000.00.00.01.01.0
11833.77011725.552301.00.01.00.00.0
22833.00034449.462001.00.01.00.00.0
33322.705021984.470611.01.00.00.00.0
43228.88003866.855201.01.00.00.00.0
..............................
27674745.32018569.861800.00.01.00.00.0
27682134.60002020.177000.00.00.01.00.0
27691926.030116450.894701.01.00.00.01.0
27702318.715021595.382291.01.00.00.00.0
27715431.60009850.432001.00.00.01.00.0
\n", "

2772 rows × 9 columns

\n", "
" ], "text/plain": [ " age bmi children charges sex_male region_northwest \\\n", "0 19 27.900 0 16884.92400 0.0 0.0 \n", "1 18 33.770 1 1725.55230 1.0 0.0 \n", "2 28 33.000 3 4449.46200 1.0 0.0 \n", "3 33 22.705 0 21984.47061 1.0 1.0 \n", "4 32 28.880 0 3866.85520 1.0 1.0 \n", "... ... ... ... ... ... ... \n", "2767 47 45.320 1 8569.86180 0.0 0.0 \n", "2768 21 34.600 0 2020.17700 0.0 0.0 \n", "2769 19 26.030 1 16450.89470 1.0 1.0 \n", "2770 23 18.715 0 21595.38229 1.0 1.0 \n", "2771 54 31.600 0 9850.43200 1.0 0.0 \n", "\n", " region_southeast region_southwest smoker_yes \n", "0 0.0 1.0 1.0 \n", "1 1.0 0.0 0.0 \n", "2 1.0 0.0 0.0 \n", "3 0.0 0.0 0.0 \n", "4 0.0 0.0 0.0 \n", "... ... ... ... \n", "2767 1.0 0.0 0.0 \n", "2768 0.0 1.0 0.0 \n", "2769 0.0 0.0 1.0 \n", "2770 0.0 0.0 0.0 \n", "2771 0.0 1.0 0.0 \n", "\n", "[2772 rows x 9 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "import pandas as pd\n",
 "\n",
 "from sklearn.preprocessing import OneHotEncoder\n",
 "import numpy as np  # type: ignore\n",
 "\n",
 "from sklearn import set_config\n",
 "\n",
 "# Make sklearn transformers return pandas DataFrames that keep column names\n",
 "# and the input index.\n",
 "set_config(transform_output=\"pandas\")\n",
 "\n",
 "random_state = 9\n",
 "\n",
 "# Categorical columns to one-hot encode (defined once, reused below).\n",
 "categorical_columns = [\"sex\", \"region\", \"smoker\"]\n",
 "\n",
 "df = pd.read_csv(\"data/Medical_insurance.csv\", index_col=False)\n",
 "\n",
 "# drop=\"first\" keeps k-1 dummy columns per feature (e.g. only sex_male).\n",
 "encoder = OneHotEncoder(sparse_output=False, drop=\"first\")\n",
 "\n",
 "# With transform_output=\"pandas\" this is already a DataFrame with named\n",
 "# columns (sex_male, region_northwest, ...) aligned on df's index.\n",
 "encoded_values_df = encoder.fit_transform(df[categorical_columns])\n",
 "\n",
 "df = pd.concat([df.drop(columns=categorical_columns), encoded_values_df], axis=1)\n",
 "\n",
 "df"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Формирование выборок" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'X_train'" ] }, "metadata": {}, "output_type": "display_data" 
}, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agebmichildrensex_maleregion_northwestregion_southeastregion_southwestsmoker_yes
21462234.58020.00.00.00.00.0
4721929.80000.00.00.01.00.0
8016435.97000.00.01.00.00.0
843734.80020.00.00.01.01.0
20286133.91501.00.00.00.00.0
...........................
9793629.92000.00.01.00.00.0
14755526.98000.01.00.00.00.0
25473442.13021.00.01.00.00.0
25532924.60020.00.00.01.00.0
19746135.91000.00.00.00.00.0
\n", "

2217 rows × 8 columns

\n", "
" ], "text/plain": [ " age bmi children sex_male region_northwest region_southeast \\\n", "2146 22 34.580 2 0.0 0.0 0.0 \n", "472 19 29.800 0 0.0 0.0 0.0 \n", "801 64 35.970 0 0.0 0.0 1.0 \n", "84 37 34.800 2 0.0 0.0 0.0 \n", "2028 61 33.915 0 1.0 0.0 0.0 \n", "... ... ... ... ... ... ... \n", "979 36 29.920 0 0.0 0.0 1.0 \n", "1475 55 26.980 0 0.0 1.0 0.0 \n", "2547 34 42.130 2 1.0 0.0 1.0 \n", "2553 29 24.600 2 0.0 0.0 0.0 \n", "1974 61 35.910 0 0.0 0.0 0.0 \n", "\n", " region_southwest smoker_yes \n", "2146 0.0 0.0 \n", "472 1.0 0.0 \n", "801 0.0 0.0 \n", "84 1.0 1.0 \n", "2028 0.0 0.0 \n", "... ... ... \n", "979 0.0 0.0 \n", "1475 0.0 0.0 \n", "2547 0.0 0.0 \n", "2553 1.0 0.0 \n", "1974 0.0 0.0 \n", "\n", "[2217 rows x 8 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'y_train'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
charges
21463925.75820
4721744.46500
80114313.84630
8439836.51900
202813143.86485
......
9794889.03680
147511082.57720
25475124.18870
25534529.47700
197413635.63790
\n", "

2217 rows × 1 columns

\n", "
" ], "text/plain": [ " charges\n", "2146 3925.75820\n", "472 1744.46500\n", "801 14313.84630\n", "84 39836.51900\n", "2028 13143.86485\n", "... ...\n", "979 4889.03680\n", "1475 11082.57720\n", "2547 5124.18870\n", "2553 4529.47700\n", "1974 13635.63790\n", "\n", "[2217 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'X_test'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agebmichildrensex_maleregion_northwestregion_southeastregion_southwestsmoker_yes
11015328.60031.00.00.01.00.0
20255633.66041.00.01.00.00.0
3073033.33010.00.01.00.00.0
8402131.10001.00.00.01.00.0
20904729.54510.01.00.00.00.0
...........................
15874832.23010.00.01.00.00.0
11774027.40010.00.00.01.00.0
12595223.18000.00.00.00.00.0
12911934.90001.00.00.01.01.0
20405935.20000.00.01.00.00.0
\n", "

555 rows × 8 columns

\n", "
" ], "text/plain": [ " age bmi children sex_male region_northwest region_southeast \\\n", "1101 53 28.600 3 1.0 0.0 0.0 \n", "2025 56 33.660 4 1.0 0.0 1.0 \n", "307 30 33.330 1 0.0 0.0 1.0 \n", "840 21 31.100 0 1.0 0.0 0.0 \n", "2090 47 29.545 1 0.0 1.0 0.0 \n", "... ... ... ... ... ... ... \n", "1587 48 32.230 1 0.0 0.0 1.0 \n", "1177 40 27.400 1 0.0 0.0 0.0 \n", "1259 52 23.180 0 0.0 0.0 0.0 \n", "1291 19 34.900 0 1.0 0.0 0.0 \n", "2040 59 35.200 0 0.0 0.0 1.0 \n", "\n", " region_southwest smoker_yes \n", "1101 1.0 0.0 \n", "2025 0.0 0.0 \n", "307 0.0 0.0 \n", "840 1.0 0.0 \n", "2090 0.0 0.0 \n", "... ... ... \n", "1587 0.0 0.0 \n", "1177 1.0 0.0 \n", "1259 0.0 0.0 \n", "1291 1.0 1.0 \n", "2040 0.0 0.0 \n", "\n", "[555 rows x 8 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'y_test'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
charges
110111253.42100
202512949.15540
3074151.02870
8401526.31200
20908930.93455
......
15878871.15170
11776496.88600
125910197.77220
129134828.65400
204012244.53100
\n", "

555 rows × 1 columns

\n", "
" ], "text/plain": [ " charges\n", "1101 11253.42100\n", "2025 12949.15540\n", "307 4151.02870\n", "840 1526.31200\n", "2090 8930.93455\n", "... ...\n", "1587 8871.15170\n", "1177 6496.88600\n", "1259 10197.77220\n", "1291 34828.65400\n", "2040 12244.53100\n", "\n", "[555 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from utils import split_stratified_into_train_val_test\n", "\n", "X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n", " df,\n", " stratify_colname=\"age\",\n", " target_colname=\"charges\",\n", " frac_train=0.80,\n", " frac_val=0,\n", " frac_test=0.20,\n", " random_state=random_state,\n", ")\n", "\n", "X_train = X_train.drop([\"charges\"], axis=1)\n", "X_test = X_test.drop([\"charges\"], axis=1)\n", "\n", "display(\"X_train\", X_train)\n", "display(\"y_train\", y_train)\n", "\n", "display(\"X_test\", X_test)\n", "display(\"y_test\", y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Определение перечня алгоритмов решения задачи аппроксимации (регрессии)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import PolynomialFeatures\n", "from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n", "\n", "random_state = 9\n", "\n", "models = {\n", " \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n", " \"linear_poly\": {\n", " \"model\": make_pipeline(\n", " PolynomialFeatures(degree=2),\n", " linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n", " )\n", " },\n", " \"linear_interact\": {\n", " \"model\": make_pipeline(\n", " PolynomialFeatures(interaction_only=True),\n", " linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n", " )\n", " },\n", " \"ridge\": {\"model\": linear_model.RidgeCV()},\n", " \"decision_tree\": {\n", " \"model\": tree.DecisionTreeRegressor(max_depth=7, 
random_state=random_state)\n", " },\n", " \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n", " \"random_forest\": {\n", " \"model\": ensemble.RandomForestRegressor(\n", " max_depth=7, random_state=random_state, n_jobs=-1\n", " )\n", " },\n", " \"mlp\": {\n", " \"model\": neural_network.MLPRegressor(\n", " activation=\"tanh\",\n", " hidden_layer_sizes=(3,),\n", " max_iter=500,\n", " early_stopping=True,\n", " random_state=random_state,\n", " )\n", " },\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Определение функции для стандартизации числовых значений для MLP" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from pandas import DataFrame\n", "from sklearn import preprocessing\n", "\n", "\n", "stndart_scaler = preprocessing.StandardScaler()\n", "\n", "\n", "def std_q(df: DataFrame) -> DataFrame:\n", " df[\"age\"] = np.array(stndart_scaler.fit_transform(df[\"age\"].to_numpy().reshape(-1, 1))).reshape(\n", " df[\"age\"].shape\n", " )\n", " df[\"bmi\"] = np.array(\n", " stndart_scaler.fit_transform(df[\"bmi\"].to_numpy().reshape(-1, 1))\n", " ).reshape(df[\"bmi\"].shape)\n", " df[\"children\"] = np.array(\n", " stndart_scaler.fit_transform(df[\"children\"].to_numpy().reshape(-1, 1))\n", " ).reshape(df[\"children\"].shape)\n", " return df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Обучение и оценка моделей с помощью различных алгоритмов" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: linear\n", "Model: linear_poly\n", "Model: linear_interact\n", "Model: ridge\n", "Model: decision_tree\n", "Model: knn\n", "Model: random_forest\n", "Model: mlp\n" ] } ], "source": [ "import math\n", "from pandas import DataFrame\n", "from sklearn import metrics\n", "\n", "for model_name in models.keys():\n", " print(f\"Model: {model_name}\")\n", "\n", " x_train: 
DataFrame = X_train.copy()\n", " x_test: DataFrame = X_test.copy()\n", "\n", " if model_name == \"mlp\":\n", " x_train = std_q(x_train)\n", " x_test = std_q(x_test)\n", " \n", " fitted_model = models[model_name][\"model\"].fit(\n", " x_train.values, y_train.values.ravel()\n", " )\n", " y_train_pred = fitted_model.predict(x_train.values)\n", " y_test_pred = fitted_model.predict(x_test.values)\n", " models[model_name][\"fitted\"] = fitted_model\n", " models[model_name][\"train_preds\"] = y_train_pred\n", " models[model_name][\"preds\"] = y_test_pred\n", " models[model_name][\"RMSE_train\"] = math.sqrt(\n", " metrics.mean_squared_error(y_train, y_train_pred)\n", " )\n", " models[model_name][\"RMSE_test\"] = math.sqrt(\n", " metrics.mean_squared_error(y_test, y_test_pred)\n", " )\n", " models[model_name][\"RMAE_test\"] = math.sqrt(\n", " metrics.mean_absolute_error(y_test, y_test_pred)\n", " )\n", " models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Вывод результатов оценки" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 RMSE_trainRMSE_testRMAE_testR2_test
random_forest3221.4697073953.66105345.7416090.901103
decision_tree3643.2791934288.04072647.3590730.883668
linear_poly4731.0246544868.81737154.2577450.850021
linear_interact4776.3937164938.69955654.6412090.845685
ridge6028.4276176216.54408165.5849480.755499
linear6028.4269936216.58882965.5808790.755496
knn8230.9590709715.10258181.1292010.402859
mlp17848.19889518518.275054116.605174-1.169619
\n" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Per-model metrics as a table (one row per model), shown sorted by test RMSE\n",
 "# with separate colour maps for the error and the score columns.\n",
 "reg_metrics = pd.DataFrame.from_dict(models, orient=\"index\").loc[\n",
 "    :, [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
 "]\n",
 "(\n",
 "    reg_metrics.sort_values(by=\"RMSE_test\")\n",
 "    .style.background_gradient(  # type: ignore\n",
 "        cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
 "    )\n",
 "    .background_gradient(\n",
 "        cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"]\n",
 "    )\n",
 ")"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Получение лучшей модели" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'random_forest'" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "# The best model is the first row after sorting by test RMSE (lowest error).\n",
 "best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").index[0])  # type: ignore\n",
 "\n",
 "display(best_model)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Вывод для обучающей выборки" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agebmichildrensex_maleregion_northwestregion_southeastregion_southwestsmoker_yeschargesChargesPred
21462234.58020.00.00.00.00.03925.758204868.448184
4721929.80000.00.00.01.00.01744.465003011.805771
8016435.97000.00.01.00.00.014313.8463013841.766282
843734.80020.00.00.01.01.039836.5190039427.673528
20286133.91501.00.00.00.00.013143.8648513575.291528
\n", "
" ], "text/plain": [ " age bmi children sex_male region_northwest region_southeast \\\n", "2146 22 34.580 2 0.0 0.0 0.0 \n", "472 19 29.800 0 0.0 0.0 0.0 \n", "801 64 35.970 0 0.0 0.0 1.0 \n", "84 37 34.800 2 0.0 0.0 0.0 \n", "2028 61 33.915 0 1.0 0.0 0.0 \n", "\n", " region_southwest smoker_yes charges ChargesPred \n", "2146 0.0 0.0 3925.75820 4868.448184 \n", "472 1.0 0.0 1744.46500 3011.805771 \n", "801 0.0 0.0 14313.84630 13841.766282 \n", "84 1.0 1.0 39836.51900 39427.673528 \n", "2028 0.0 0.0 13143.86485 13575.291528 " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Training features next to the true target and the best model's predictions\n",
 "# (predictions are aligned on y_train's index).\n",
 "X_train.join(y_train).assign(\n",
 "    ChargesPred=pd.Series(models[best_model][\"train_preds\"], index=y_train.index)\n",
 ").head(5)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Вывод для тестовой выборки" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agebmichildrensex_maleregion_northwestregion_southeastregion_southwestsmoker_yeschargesChargesPred
11015328.60031.00.00.01.00.011253.4210012139.772544
20255633.66041.00.01.00.00.012949.1554014977.306757
3073033.33010.00.01.00.00.04151.028705778.492115
8402131.10001.00.00.01.00.01526.312003324.843009
20904729.54510.01.00.00.00.08930.9345511318.629065
\n", "
" ], "text/plain": [ " age bmi children sex_male region_northwest region_southeast \\\n", "1101 53 28.600 3 1.0 0.0 0.0 \n", "2025 56 33.660 4 1.0 0.0 1.0 \n", "307 30 33.330 1 0.0 0.0 1.0 \n", "840 21 31.100 0 1.0 0.0 0.0 \n", "2090 47 29.545 1 0.0 1.0 0.0 \n", "\n", " region_southwest smoker_yes charges ChargesPred \n", "1101 1.0 0.0 11253.42100 12139.772544 \n", "2025 0.0 0.0 12949.15540 14977.306757 \n", "307 0.0 0.0 4151.02870 5778.492115 \n", "840 1.0 0.0 1526.31200 3324.843009 \n", "2090 0.0 0.0 8930.93455 11318.629065 " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Test features next to the true target and the best model's predictions\n",
 "# (predictions are aligned on y_test's index).\n",
 "X_test.join(y_test).assign(\n",
 "    ChargesPred=pd.Series(models[best_model][\"preds\"], index=y_test.index)\n",
 ").head(5)"
] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 2 }