
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Вариант: Список людей.\n",
"ссылка на датасет: https://www.kaggle.com/datasets/imoore/age-dataset"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1000 entries, 0 to 999\n",
"Data columns (total 10 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Id 1000 non-null object \n",
" 1 Name 1000 non-null object \n",
" 2 Short description 1000 non-null object \n",
" 3 Gender 995 non-null object \n",
" 4 Country 962 non-null object \n",
" 5 Occupation 998 non-null object \n",
" 6 Birth year 1000 non-null int64 \n",
" 7 Death year 999 non-null float64\n",
" 8 Manner of death 372 non-null object \n",
" 9 Age of death 999 non-null float64\n",
"dtypes: float64(2), int64(1), object(7)\n",
"memory usage: 78.3+ KB\n"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.linear_model import LinearRegression, LogisticRegression\n",
"from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier, GradientBoostingClassifier\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.metrics import (\n",
" precision_score, recall_score, accuracy_score, roc_auc_score, f1_score,\n",
" matthews_corrcoef, cohen_kappa_score, confusion_matrix\n",
")\n",
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"# Функция для применения oversampling\n",
"def apply_oversampling(X, y):\n",
" oversampler = RandomOverSampler(random_state=42)\n",
" X_resampled, y_resampled = oversampler.fit_resample(X, y)\n",
" return X_resampled, y_resampled\n",
"\n",
"# Функция для применения undersampling\n",
"def apply_undersampling(X, y):\n",
" undersampler = RandomUnderSampler(random_state=42)\n",
" X_resampled, y_resampled = undersampler.fit_resample(X, y)\n",
" return X_resampled, y_resampled\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
"):\n",
"\n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
"\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
"\n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
"\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
"\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"\n",
" return df_train, df_val, df_test\n",
"\n",
"\n",
"df = pd.read_csv(\"../static/csv/AgeDataset-V1.csv\", nrows=1000)\n",
"df.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Как бизнес-цели выделим следующие 2 варианта: 1) GameDev. Создание игры про конкретного персонажа, живущего в конкретном временном промежутке в конкретной стране. 2) Классификация людей по возрастным группам, что может быть полезно для рекламных целей\n",
"\n",
"\n",
"Выполним подготовку данных\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"df.fillna({\"Gender\": \"NaN\", \"Country\": \"NaN\", \"Occupation\" : \"NaN\", \"Manner of death\" : \"NaN\"}, inplace=True)\n",
"df = df.dropna()\n",
"df['Country'] = df['Country'].str.split('; ')\n",
"df = df.explode('Country')\n",
"data = df.copy()\n",
"\n",
"\n",
"value_counts = data[\"Country\"].value_counts()\n",
"rare = value_counts[value_counts < 100].index\n",
"data = data[~data[\"Country\"].isin(rare)]\n",
"\n",
"data.drop(data[~data['Gender'].isin(['Male', 'Female'])].index, inplace=True)\n",
"\n",
"data1 = pd.get_dummies(data, columns=['Gender', 'Country', 'Occupation'], drop_first=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определить достижимый уровень качества модели для каждой задачи. На основе имеющихся данных уровень качества моделей не будет высоким, поскольку все таки длительность жизни лишь примерная и точно ее угадать невозможно.\n",
"\n",
"Выберем ориентиры для наших 2х задач: 1)Регрессии - средний возраст человека 2)Классификации - аиболее часто встречающаяся возрастная группа\n",
"\n",
"Построим конвейер."
]
},
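{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal baseline sketch (not part of the original notebook): assuming the `data` frame prepared above, it computes the naive reference scores mentioned here, i.e. always predicting the mean age at death for regression and the most frequent age group for classification."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Naive regression baseline: always predict the mean age at death\n",
"mean_age = data['Age of death'].mean()\n",
"baseline_mae = (data['Age of death'] - mean_age).abs().mean()\n",
"print(f'Mean age at death: {mean_age:.1f}, MAE of mean-only prediction: {baseline_mae:.1f}')\n",
"\n",
"# Naive classification baseline: always predict the most frequent age group\n",
"# (the same bins are created again later in the notebook)\n",
"groups = pd.cut(data['Age of death'], bins=[0, 18, 30, 50, 70, 100], labels=['0-18', '19-30', '31-50', '51-70', '71+'])\n",
"print(f'Accuracy of majority-class prediction: {groups.value_counts(normalize=True).max():.2f}')"
]
},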
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Id', 'Name', 'Short description', 'Gender', 'Country', 'Occupation',\n",
" 'Birth year', 'Death year', 'Manner of death', 'Age of death'],\n",
" dtype='object')\n"
]
}
],
"source": [
"print(data.columns)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters for Linear Regression: {}\n",
"Best parameters for Random Forest Regressor: {'model__max_depth': None, 'model__n_estimators': 200}\n",
"Best parameters for Gradient Boosting Regressor: {'model__learning_rate': 0.1, 'model__max_depth': 5, 'model__n_estimators': 300}\n"
]
}
],
"source": [
"X_reg = data1.drop(['Id', 'Name', 'Age of death', 'Short description', 'Manner of death'], axis=1)\n",
"y_reg = data1['Age of death']\n",
"\n",
"# Разделение данных\n",
"X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(X_reg, y_reg, test_size=0.2, random_state=42)\n",
"\n",
"# Выбор моделей для регрессии\n",
"models_reg = {\n",
" 'Linear Regression': LinearRegression(),\n",
" 'Random Forest Regressor': RandomForestRegressor(random_state=42),\n",
" 'Gradient Boosting Regressor': GradientBoostingRegressor(random_state=42)\n",
"}\n",
"\n",
"# Создание конвейера для регрессии\n",
"pipelines_reg = {}\n",
"for name, model in models_reg.items():\n",
" pipelines_reg[name] = Pipeline([\n",
" ('scaler', StandardScaler()),\n",
" ('model', model)\n",
" ])\n",
"\n",
"# Определение сетки гиперпараметров для регрессии\n",
"param_grids_reg = {\n",
" 'Linear Regression': {},\n",
" 'Random Forest Regressor': {\n",
" 'model__n_estimators': [100, 200, 300],\n",
" 'model__max_depth': [None, 10, 20, 30]\n",
" },\n",
" 'Gradient Boosting Regressor': {\n",
" 'model__n_estimators': [100, 200, 300],\n",
" 'model__learning_rate': [0.01, 0.1, 0.2],\n",
" 'model__max_depth': [3, 5, 7]\n",
" }\n",
"}\n",
"\n",
"# Настройка гиперпараметров для регрессии\n",
"best_models_reg = {}\n",
"for name, pipeline in pipelines_reg.items():\n",
" grid_search = GridSearchCV(pipeline, param_grids_reg[name], cv=5, scoring='neg_mean_squared_error')\n",
" grid_search.fit(X_train_reg, y_train_reg)\n",
" best_models_reg[name] = grid_search.best_estimator_\n",
" print(f'Best parameters for {name}: {grid_search.best_params_}')\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: Linear Regression\n",
"Model: Random Forest Regressor\n",
"Model: Gradient Boosting Regressor\n",
"{'Linear Regression': {'pipeline': Pipeline(steps=[('scaler', StandardScaler()), ('model', LinearRegression())]), 'preds_train': array([29., 53., 82., 55., 62., 42., 21., 75., 52., 32., 36., 73., 37.,\n",
" 46., 72., 27., 71., 83., 55., 22., 65., 47., 57., 78., 91., 60.,\n",
" 72., 67., 61., 88., 60., 88., 91., 52., 84., 93., 44., 58., 90.,\n",
" 94., 59., 92., 84., 64., 51., 93., 79., 81., 72., 78., 74., 27.,\n",
" 92., 46., 72., 78., 96., 95., 77., 49., 87., 90., 88., 88., 65.,\n",
" 42., 74., 87., 41., 77., 92., 69., 94., 49., 85., 79., 73., 43.,\n",
" 55., 73., 51., 79., 52., 63., 60., 68., 70., 39., 59., 79., 27.,\n",
" 57., 78., 80., 19., 69., 68., 39., 60., 67., 68., 87., 81., 62.,\n",
" 73., 23., 69., 67., 47., 45., 90., 82., 87., 70., 77., 67., 66.,\n",
" 84., 64., 54., 46., 75., 84., 36., 72., 42., 52., 48., 55., 89.,\n",
" 64., 80., 28., 52., 81., 63., 74., 68., 53., 66., 36., 80., 75.,\n",
" 78., 34., 68., 71., 78., 83., 93., 68., 63., 79., 51., 68., 65.,\n",
" 87., 66., 52., 63., 88., 83., 85., 39., 78., 80., 65., 71., 76.,\n",
" 84., 51., 86., 65., 66., 80., 56., 67., 23., 49., 89., 69., 92.,\n",
" 63., 83., 62., 82., 34., 73., 59., 79., 60., 44., 53., 75., 71.,\n",
" 64., 60., 62., 85., 95., 90., 56., 94., 99., 51., 59., 88., 21.,\n",
" 65., 97., 67., 42., 93., 54., 56., 56., 68., 95., 55., 69., 65.,\n",
" 59., 56., 72., 28., 75., 56., 65., 74., 75., 80., 82.]), 'preds_test': array([56. , 22. , 87. , 88. , 25. ,\n",
" 54. , 87. , 85. , 60. , 88. ,\n",
" 42. , 72. , 69. , 82. , 81. ,\n",
" 48. , 94. , 56. , 65. , 86. ,\n",
" 74. , 62. , 99. , 66. , 74. ,\n",
" 59. , 60. , 64. , 64. , 62. ,\n",
" 71. , 72. , 77. , 85. , 81. ,\n",
" 81. , 55.84443686, 40. , 69. , 66. ,\n",
" 95. , 40. , 81. , 75. , 91. ,\n",
" 82. , 76. , 66. , 54. , 59. ,\n",
" 80. , 45. , 44. , 92. , 67. ,\n",
" 86. , 89. , 89. , 53. ]), 'MSE_train': np.float64(7.202572085669638e-26), 'MSE_test': np.float64(0.0004101676300062632), 'R2_train': 1.0, 'R2_test': 0.9999986319859576, 'MAE_train': np.float64(2.193188159818137e-13), 'MAE_test': np.float64(0.0026366633706367913)}, 'Random Forest Regressor': {'pipeline': Pipeline(steps=[('scaler', StandardScaler()),\n",
" ('model',\n",
" RandomForestRegressor(n_estimators=200, random_state=42))]), 'preds_train': array([32.415, 54.05 , 80.21 , 53.42 , 61.825, 48.38 , 31.21 , 75.14 ,\n",
" 53.66 , 44.475, 37.86 , 69.925, 40.7 , 49.45 , 70.985, 31.605,\n",
" 71.845, 77.67 , 55.85 , 23.27 , 64.665, 51.965, 58.55 , 67.885,\n",
" 91.005, 60.08 , 72.905, 68.86 , 63.055, 87.615, 58.84 , 87.08 ,\n",
" 90.725, 52.17 , 80.965, 91.69 , 48.03 , 58.71 , 89.26 , 92.89 ,\n",
" 59.015, 91.75 , 82.57 , 63.895, 55.675, 92.27 , 79.905, 78.265,\n",
" 72.795, 76.885, 73.87 , 33.285, 91.75 , 47.07 , 72.47 , 76.92 ,\n",
" 94.385, 90.31 , 74.365, 50.7 , 85.73 , 83.985, 87.175, 86.815,\n",
" 62.32 , 46.985, 67.69 , 88.02 , 41.14 , 76.32 , 87.4 , 66.825,\n",
" 92.305, 51.86 , 83.51 , 80.005, 70.49 , 44.39 , 56.58 , 73.695,\n",
" 52.235, 79.01 , 50.495, 65.565, 62.43 , 66.77 , 69.69 , 40.495,\n",
" 63.64 , 79.755, 28.875, 59.86 , 79.155, 81.925, 33.975, 69.73 ,\n",
" 71.19 , 47.59 , 54.73 , 67.71 , 69.41 , 84.08 , 80.425, 60.615,\n",
" 73.29 , 24.925, 68.7 , 66.365, 46.64 , 50.88 , 89.61 , 78.34 ,\n",
" 85.715, 67.64 , 79.165, 64.275, 65.885, 81.285, 64.4 , 57.835,\n",
" 49.395, 75.095, 84.47 , 40.765, 62.5 , 44.53 , 56.76 , 47.27 ,\n",
" 53.965, 87.91 , 64.765, 71.42 , 28.32 , 55.25 , 80.665, 60.975,\n",
" 73.21 , 67.89 , 54.7 , 67.895, 44.685, 79.79 , 75.025, 76.505,\n",
" 38.565, 65.68 , 72.485, 77.19 , 84.71 , 91.46 , 67.54 , 61.705,\n",
" 77.92 , 53.51 , 69.705, 66.27 , 87.135, 65.785, 57.23 , 65.945,\n",
" 88.88 , 81.18 , 82.655, 40.31 , 77.985, 80.435, 62.81 , 73.05 ,\n",
" 75.88 , 82.215, 58.5 , 81.83 , 61.155, 65.455, 76.965, 57.17 ,\n",
" 66.44 , 34.205, 48.505, 87.855, 64.585, 91.78 , 63.915, 77.265,\n",
" 64.48 , 81.42 , 40.195, 71.515, 57.38 , 74.945, 57.52 , 47.795,\n",
" 51.94 , 74.025, 68.09 , 64.38 , 62.535, 59.58 , 80.365, 93.635,\n",
" 89.98 , 56.94 , 92.79 , 97.075, 53.34 , 58.425, 87.805, 24.685,\n",
" 64.205, 94.825, 65.02 , 43.075, 91.215, 56.39 , 56.38 , 57.195,\n",
" 68.27 , 90.75 , 56.04 , 68.77 , 65.135, 58.49 , 55.54 , 73.21 ,\n",
" 40.095, 75.28 , 55.98 , 63.535, 74.15 , 74.775, 76.405, 78.46 ]), 'preds_test': array([56.47 , 42.345, 87.025, 86.56 , 40.42 , 49.335, 86.02 , 81.75 ,\n",
" 62.32 , 90.2 , 63.72 , 74.3 , 67.43 , 58.9 , 83.06 , 46.655,\n",
" 92.365, 32.505, 71.02 , 89.43 , 63.06 , 63.645, 92.385, 53.625,\n",
" 71.25 , 68.73 , 66.38 , 70.14 , 62.755, 65.02 , 72.21 , 73.205,\n",
" 74.06 , 87.985, 83.44 , 78.265, 53.98 , 52.355, 61.145, 69.6 ,\n",
" 89.645, 55.83 , 77.695, 59.03 , 89.61 , 83.235, 70.58 , 71.92 ,\n",
" 69.175, 58.48 , 75.345, 59.55 , 52.395, 85.715, 65.425, 87.95 ,\n",
" 83.12 , 87.76 , 55.63 ]), 'MSE_train': np.float64(10.585386853448275), 'MSE_test': np.float64(73.41657415254235), 'R2_train': 0.9680762587778711, 'R2_test': 0.7551369317321679, 'MAE_train': np.float64(2.189698275862069), 'MAE_test': np.float64(6.021271186440678)}, 'Gradient Boosting Regressor': {'pipeline': Pipeline(steps=[('scaler', StandardScaler()),\n",
" ('model',\n",
" GradientBoostingRegressor(max_depth=5, n_estimators=300,\n",
" random_state=42))]), 'preds_train': array([28.72956041, 53.12127389, 82.08536004, 55.09719521, 61.75388192,\n",
" 41.97235916, 21.14883789, 74.54323397, 52.25364062, 32.10924489,\n",
" 36.08384782, 72.70845527, 37.11384401, 46.04284373, 72.05464788,\n",
" 27.19660712, 71.03059415, 82.78900252, 54.82581543, 22.08471572,\n",
" 64.95177947, 46.91864608, 57.0614107 , 77.77389579, 90.91849183,\n",
" 60.00063443, 72.32250587, 67.19480682, 61.21037107, 87.9493728 ,\n",
" 59.76167757, 87.9242189 , 91.11672076, 51.981465 , 83.97286576,\n",
" 92.99512162, 44.2265744 , 57.97309623, 89.8580269 , 93.93278779,\n",
" 58.9790766 , 92.10213846, 83.92831871, 64.01048318, 50.85907853,\n",
" 93.03022066, 79.35509757, 80.97385883, 72.26475898, 78.0822317 ,\n",
" 73.74605417, 26.94997048, 91.93353737, 46.11073777, 71.86943063,\n",
" 78.22666513, 95.97811062, 94.90309836, 76.79483994, 49.04234743,\n",
" 87.10113854, 90.00164369, 88.15604432, 88.06107202, 64.86758165,\n",
" 42.0662194 , 73.86206285, 87.06076311, 40.77837315, 76.93677631,\n",
" 91.80841172, 68.86009114, 93.99977552, 49.01611104, 85.1215977 ,\n",
" 79.20236795, 72.81006079, 42.88133804, 55.07471142, 73.08367579,\n",
" 51.10101262, 79.26235085, 51.93996986, 63.1400842 , 60.16031868,\n",
" 67.74505892, 69.92474149, 39.10249238, 59.19318532, 79.00162184,\n",
" 27.13287068, 57.14727171, 78.13131855, 80.04141944, 19.21578015,\n",
" 69.08228133, 68.34019354, 39.23243336, 59.69048347, 66.68312364,\n",
" 67.90455008, 86.87414348, 80.96263263, 62.01895029, 72.84596009,\n",
" 22.9688195 , 69.07247695, 67.07765118, 46.8756752 , 45.26635382,\n",
" 89.94753457, 82.01807007, 86.98465799, 70.20029179, 77.11751998,\n",
" 66.85572827, 65.88490085, 83.94892282, 64.05271048, 54.31302531,\n",
" 46.21858535, 74.71421415, 84.18057805, 35.93101794, 71.74544023,\n",
" 41.87496037, 51.97305771, 47.95325267, 54.99051089, 88.93282021,\n",
" 63.9305233 , 79.93395234, 27.90266723, 51.99225237, 80.96691129,\n",
" 62.96957561, 74.2341931 , 68.04553945, 53.25614425, 66.06230994,\n",
" 36.07177583, 79.73194835, 74.81165201, 77.92313651, 34.03449993,\n",
" 67.979729 , 71.21728118, 77.76422978, 83.48678921, 92.88231457,\n",
" 68.07190528, 63.0157021 , 78.64975048, 51.06214065, 68.08719803,\n",
" 65.07261616, 87.04733437, 66.02343628, 52.0736837 , 63.22318294,\n",
" 88.04851864, 82.82539568, 84.97439652, 39.24799156, 78.25738675,\n",
" 79.93756933, 64.4750149 , 71.56468737, 76.18401232, 84.06330088,\n",
" 51.02264414, 85.98802018, 64.92800866, 66.16320124, 79.94939849,\n",
" 56.07374628, 66.98345294, 23.02540478, 49.17449175, 88.88588133,\n",
" 68.9329792 , 92.03345878, 63.07777892, 82.73557105, 61.78437332,\n",
" 81.8909867 , 34.21616731, 72.87348414, 58.98687689, 78.8140383 ,\n",
" 59.9574234 , 44.19210735, 52.71369582, 75.20218936, 70.59615384,\n",
" 63.54886587, 60.49279846, 61.78645898, 84.87971032, 94.81801802,\n",
" 89.90842136, 55.66192951, 93.90927911, 98.9415322 , 51.01506961,\n",
" 58.55722323, 87.77450912, 20.9321725 , 64.89912387, 97.0158939 ,\n",
" 67.06399678, 41.91876756, 92.91632536, 54.03711532, 56.10247109,\n",
" 55.84819722, 67.98653048, 95.00209989, 54.94376476, 69.02145146,\n",
" 65.17895584, 59.02771118, 55.92396986, 72.8440164 , 28.29625663,\n",
" 75.01157336, 56.0700562 , 64.97176071, 74.08213306, 74.93307889,\n",
" 79.9003625 , 81.82188841]), 'preds_test': array([66.503057 , 38.21021974, 89.04970956, 87.89189709, 36.73554665,\n",
" 45.73843308, 87.79291212, 82.13464953, 60.57314255, 91.22929864,\n",
" 62.23036581, 71.58664491, 66.17112665, 62.51658214, 81.60921548,\n",
" 38.83018398, 91.30064235, 24.90175483, 69.16155336, 87.4365223 ,\n",
" 71.56622022, 63.57230002, 93.97558163, 49.45397887, 68.85601209,\n",
" 68.60673528, 63.94743518, 68.42632232, 65.30704897, 66.74142159,\n",
" 68.75949485, 74.90532442, 73.25421167, 89.6482385 , 82.66649342,\n",
" 78.86658868, 56.09338908, 61.2786305 , 60.68340277, 71.36372731,\n",
" 90.85782508, 52.24020316, 83.95183498, 62.00353481, 89.95327108,\n",
" 86.00387125, 71.50207355, 76.51105405, 67.41310326, 58.59170399,\n",
" 76.96828297, 62.60133656, 46.93230456, 87.0082761 , 68.74473539,\n",
" 88.07943744, 83.14111532, 87.20969454, 53.47940315]), 'MSE_train': np.float64(0.03170555378857528), 'MSE_test': np.float64(78.66069437097792), 'R2_train': 0.999904381397821, 'R2_test': 0.7376464483927592, 'MAE_train': np.float64(0.13311180849830512), 'MAE_test': np.float64(6.141062895444746)}}\n"
]
}
],
"source": [
"# Обучение моделей и оценка качества\n",
"results_reg = {}\n",
"\n",
"for model_name in best_models_reg.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model_pipeline = best_models_reg[model_name]\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train_reg)\n",
" y_test_predict = model_pipeline.predict(X_test_reg)\n",
"\n",
" results_reg[model_name] = {\n",
" \"pipeline\": model_pipeline,\n",
" \"preds_train\": y_train_predict,\n",
" \"preds_test\": y_test_predict,\n",
" \"MSE_train\": mean_squared_error(y_train_reg, y_train_predict),\n",
" \"MSE_test\": mean_squared_error(y_test_reg, y_test_predict),\n",
" \"R2_train\": r2_score(y_train_reg, y_train_predict),\n",
" \"R2_test\": r2_score(y_test_reg, y_test_predict),\n",
" \"MAE_train\": mean_absolute_error(y_train_reg, y_train_predict),\n",
" \"MAE_test\": mean_absolute_error(y_test_reg, y_test_predict),\n",
" }\n",
"\n",
"# Теперь результаты каждой модели находятся в results_reg\n",
"print(results_reg)\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"data2 = data.drop(['Short description', 'Manner of death', 'Gender', 'Country', 'Occupation'], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Birth year', 'Death year'], dtype='object')\n",
"Best parameters for Logistic Regression: {'model__C': 10, 'model__solver': 'lbfgs'}\n",
"Best parameters for Random Forest Classifier: {'model__max_depth': 30, 'model__n_estimators': 200}\n",
"Best parameters for Gradient Boosting Classifier: {'model__learning_rate': 0.1, 'model__max_depth': 7, 'model__n_estimators': 200}\n",
"Model: Logistic Regression\n",
"Model: Random Forest Classifier\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\alexk\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\alexk\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: Gradient Boosting Classifier\n"
]
}
],
"source": [
"# Создание возрастных групп\n",
"bins = [0, 18, 30, 50, 70, 100]\n",
"labels = ['0-18', '19-30', '31-50', '51-70', '71+']\n",
"data['Age Group'] = pd.cut(data['Age of death'], bins=bins, labels=labels)\n",
"data2['Age Group'] = pd.cut(data2['Age of death'], bins=bins, labels=labels)\n",
"\n",
"# Выбор признаков и целевой переменной для классификации\n",
"X_class = data2.drop(['Id', 'Name', 'Age of death', 'Age Group'], axis=1)\n",
"y_class = data['Age Group'] \n",
"print(X_class.columns)\n",
"# Разделение данных\n",
"X_train_class, X_test_class, y_train_class, y_test_class = train_test_split(X_class, y_class, test_size=0.2, random_state=42)\n",
"\n",
"# Выбор моделей для классификации\n",
"models_class = {\n",
" 'Logistic Regression': LogisticRegression(random_state=42, max_iter=5000, solver='liblinear'),\n",
" 'Random Forest Classifier': RandomForestClassifier(random_state=42),\n",
" 'Gradient Boosting Classifier': GradientBoostingClassifier(random_state=42)\n",
"}\n",
"\n",
"# Создание конвейера для классификации\n",
"pipelines_class = {}\n",
"for name, model in models_class.items():\n",
" pipelines_class[name] = Pipeline([\n",
" ('scaler', StandardScaler()),\n",
" ('model', model)\n",
" ])\n",
"\n",
"# Определение сетки гиперпараметров для классификации\n",
"\n",
"param_grids_class = {\n",
" 'Logistic Regression': {\n",
" 'model__C': [0.1, 1, 10],\n",
" 'model__solver': ['lbfgs', 'liblinear']\n",
" },\n",
" 'Random Forest Classifier': {\n",
" 'model__n_estimators': [100, 200, 300],\n",
" 'model__max_depth': [None, 10, 20, 30]\n",
" },\n",
" 'Gradient Boosting Classifier': {\n",
" 'model__n_estimators': [100, 200, 300],\n",
" 'model__learning_rate': [0.01, 0.1, 0.2],\n",
" 'model__max_depth': [3, 5, 7]\n",
" }\n",
"}\n",
"# Убрал определение параметров поскольку уже был предподсчет данных, но вылетела ошибка. Сохранил лучшие параметры\n",
"\n",
"param_grids_class = {\n",
" 'Logistic Regression': {\n",
" 'model__C': [10],\n",
" 'model__solver': ['lbfgs']\n",
" },\n",
" 'Random Forest Classifier': {\n",
" 'model__n_estimators': [200],\n",
" 'model__max_depth': [ 30]\n",
" },\n",
" 'Gradient Boosting Classifier': {\n",
" 'model__n_estimators': [200],\n",
" 'model__learning_rate': [0.1],\n",
" 'model__max_depth': [7]\n",
" }\n",
"}\n",
"\n",
"# Настройка гиперпараметров для классификации\n",
"best_models_class = {}\n",
"for name, pipeline in pipelines_class.items():\n",
" grid_search = GridSearchCV(pipeline, param_grids_class[name], cv=5, scoring='accuracy', n_jobs=-1)\n",
" grid_search.fit(X_train_class, y_train_class)\n",
" best_models_class[name] = {\"model\": grid_search.best_estimator_}\n",
" print(f'Best parameters for {name}: {grid_search.best_params_}')\n",
"\n",
"# Обучение моделей и оценка качества\n",
"for model_name in best_models_class.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = best_models_class[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"scaler\", StandardScaler()), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train_class, y_train_class)\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train_class)\n",
" y_test_probs = model_pipeline.predict_proba(X_test_class)\n",
" y_test_predict = model_pipeline.predict(X_test_class)\n",
"\n",
" best_models_class[model_name][\"pipeline\"] = model_pipeline\n",
" best_models_class[model_name][\"probs\"] = y_test_probs\n",
" best_models_class[model_name][\"preds\"] = y_test_predict\n",
"\n",
" best_models_class[model_name][\"Precision_train\"] = precision_score(y_train_class, y_train_predict, average='weighted')\n",
" best_models_class[model_name][\"Precision_test\"] = precision_score(y_test_class, y_test_predict, average='weighted')\n",
" best_models_class[model_name][\"Recall_train\"] = recall_score(y_train_class, y_train_predict, average='weighted')\n",
" best_models_class[model_name][\"Recall_test\"] = recall_score(y_test_class, y_test_predict, average='weighted')\n",
" best_models_class[model_name][\"Accuracy_train\"] = accuracy_score(y_train_class, y_train_predict)\n",
" best_models_class[model_name][\"Accuracy_test\"] = accuracy_score(y_test_class, y_test_predict)\n",
" best_models_class[model_name][\"ROC_AUC_test\"] = roc_auc_score(y_test_class, y_test_probs, multi_class='ovr')\n",
" best_models_class[model_name][\"F1_train\"] = f1_score(y_train_class, y_train_predict, average='weighted')\n",
" best_models_class[model_name][\"F1_test\"] = f1_score(y_test_class, y_test_predict, average='weighted')\n",
" best_models_class[model_name][\"MCC_test\"] = matthews_corrcoef(y_test_class, y_test_predict)\n",
" best_models_class[model_name][\"Cohen_kappa_test\"] = cohen_kappa_score(y_test_class, y_test_predict)\n",
" best_models_class[model_name][\"Confusion_matrix\"] = confusion_matrix(y_test_class, y_test_predict)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAbcAAAQ9CAYAAADahAPyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADyfElEQVR4nOzdeVhU1f8H8PewDfsmm8giiuK+hGm4mySSX5fUcsvANQv3ncx9obRcyj0NsjSzci9R3LXQ1ERzIxRUVMAFWVWWmfP7gx+TV1BnEJhheL+e5z6Pc+fMuZ97GflwlnuuTAghQEREpEcMtB0AERFRaWNyIyIivcPkRkREeofJjYiI9A6TGxER6R0mNyIi0jtMbkREpHeY3IiISO8wuRERkd5hciOd0L59e7Rv377U6qtevTqCg4NLrT4CZDIZZs2ape0wiNTC5EYSERERkMlkOH36tLZDeak///wTs2bNQlpaWpkep3r16pDJZKrNwsICzZs3x4YNG8r0uERUckbaDoAIAPbt26fxZ/7880/Mnj0bwcHBsLW1lbwXGxsLA4PS+9utSZMmmDBhAgAgKSkJ69atQ1BQEHJycjBs2LBSO44ue/z4MYyM+CuDKgZ+U0knmJiYlGp9crm8VOurVq0a3n//fdXr4OBg1KhRA0uWLCn35JadnQ0LC4tyPSYAmJqalvsxiUqK3ZJUImfPnkVgYCCsra1haWmJjh074sSJE0XKnT9/Hu3atYOZmRnc3Nwwb948hIeHQyaT4fr166pyxY25ff3116hfvz7Mzc1hZ2eHZs2aYdOmTQCAWbNmYdKkSQAALy8vVZdhYZ3FjbmlpaVh3LhxqF69OuRyOdzc3PDBBx/g/v37Gp+/o6Mj6tSpg2vXrkn2K5VKLF26FPXr14epqSmcnZ3x4Ycf4uHDh0XKzZo1C66urjA3N0eHDh1w6dKlInEXdhMfOXIEH3/8MZycnODm5qZ6f8+ePWjTpg0sLCxgZWWFLl264OLFi5JjJScnY9CgQXBzc4NcLkfVqlXRvXt3yfU/ffo0AgIC4ODgADMzM3h5eWHw4MGSeoobc1Pne1B4Dn/88QfGjx8PR0dHWFhY4J133sG9e/fUveREGmHLjTR28eJFtGnTBtbW1pg8eTKMjY2xZs0atG/fHkeOHEGLFi0AALdv30aHDh0gk8kQGhoKCwsLrFu3Tq1W1TfffIPRo0ejd+/eGDNmDJ48eYLz58/j5MmT6N+/P3r27Il///0XP/74I5YsWQIHBwcABUmnOFlZWWjTpg0uX76MwYMH47XXXsP9+/exc+dO3Lp1S/V5deXn5+PWrVuws7OT7P/www8RERGBQYMGYfTo0UhISMDy5ctx9uxZ/PHHHzA2NgYAhIaGYuHChejatSsCAgJw7tw5BAQE4MmTJ8Ue7+OPP4ajoyNmzJiB7OxsAMD333+PoKAgBAQE4PPPP8ejR4+watUqtG7dGmfPnkX16tUBAL169cLFixcxatQoVK9eHXfv3kVUVBRu3rypet2pUyc4Ojpi6tSpsLW1xfXr17F169YXXgN1vweFRo0aBTs7O8ycORPXr1/H0qVLMXLkSPz0008aXXsitQiip4SHhwsA4tSpU88t06NHD2FiYiKuXbum2nfnzh1hZWUl2rZtq9o3atQoIZPJxNmzZ1X7Hjx4IOzt7QUAkZCQoNrfrl070a5dO9Xr7t27i/r1678w1kWLFhWpp5Cnp6cICgpSvZ4xY4YAILZu3VqkrFKpfOFxPD09RadOncS9e/fEvXv3xD///CMGDhwoAIiQkBBVuWPHjgkAYuPGjZLPR0ZGSvYnJycLIyMj0aNHD0m5WbNmCQCSuAt/Hq1btxb5+fmq/ZmZmcLW1lYMGzZMUkdycrKwsbFR7X/48KEAIBYtWvTc89u2bdtLf+ZCCAFAzJw5U/Va3e9B4Tn4+/tLrvW4ceOEoaGhSEtLe+FxiUqC3ZKkEYVCgX379qFHjx6oUaOGan/VqlXRv39/HD9+HBkZGQCAyMhI+Pn5oUmTJqpy9vb2GDBgwEuPY2tri1u3buHUqVOlEvevv/6Kxo0b45133inynkwme+nn9+3bB0dHRzg6OqJhw4b4/vvvMWjQICxatEhV5ueff4aNjQ3eeust3L9/X7X5+vrC0tIShw4dAgAcOHAA+fn5+PjjjyXHGDVq1HOPP2zYMBgaGqpeR0VFIS0tDf369ZMcy9DQEC1atFAdy8zMDCYmJjh8+HCRrtFChZNxdu/ejby8vJdeC0Cz70Gh4cOHS651mzZtoFAocOPGDbWOSaQJJjfSyL179/Do0SP4+PgUea9u3bpQKpVITEwEANy4cQPe3t5FyhW371lTpkyBpaUlmjdvjlq1aiEkJAR//PFHieO+du0aGjRoUOLPt2jRAlFRUYiMjMQXX3wBW1tbPHz4UDIRJi4uDunp6XByclIlwsItKysLd+/eBQDVL/Nnr4O9vX2Rbs5CXl5ektdxcXEAgDfffLPIsfbt26c6llwux+eff449e/bA2dkZbdu2xcKFC5GcnKyqq127dujVqxdmz54NBwcHdO/eHeHh4cjJyXnu9dDke1DIw8ND8rrwXJ+XdIleBcfcSCfVrVsXsbGx2L17NyIjI/Hrr79i5cqVmDFjBmbPnl3u8Tg4OMDf3x8AEBAQgDp16uB///sfli1bhvHjxwMomCTi5OSEjRs3FlvH88YD1WFmZiZ5rVQqARSMu7m4uBQp//SU/bFjx6Jr167Yvn079u7di+nTpyMsLAwHDx5E06ZNIZPJ8Msvv+DEiRPYtWsX9u7di8GDB+PLL7/EiRMnYGlpWeK4n/Z0y/NpQohSqZ/oaUxupBFHR0eYm5sjNja2yHtXrlyBgYEB3N3dAQCenp64evVqkXLF7SuOhYUF+vTpgz59+iA3Nxc9e/bE/PnzERoaClNTU7W6EwvVrFkTFy5cULv8y3Tp0gXt2rXDggUL8OGHH8LCwgI1a9bE/v370apVqyLJ6Gmenp4ACq7D0y2yBw8eqN2KqVmzJgDAyclJlXRfVn7ChAmYMGEC4uLi0KRJE3z55Zf44YcfVGXeeOMNvPHGG5g/fz42bdqEAQMGYPPmzRg6dGiR+jT5HhBpA7slSSOGhobo1KkTduzYIZlKnpKSgk2bNqF169awtrYGUNDCiY6ORkxMjKpcamrqc1s2T3vw4IHktYmJCerVqwchhGpcqPBeL3VWKOnVqxfOnTuHbdu2FXmvpC2HKVOm4MGDB/jmm28AAO+99x4UCgXmzp1bpGx+fr4qzo4dO8LIyAirVq2SlFm+fLnaxw4ICIC1tTUWLFhQ7DhZ4RT7R48eFZmBWbNmTVhZWam6HR8+fFjkGhSOkz6va1KT7wGRNrDlRsX69ttvERkZWWT/mDFjMG/ePERFRaF169b4+OOPYWRkhDVr1iAnJwcLFy5UlZ08eTJ++OEHvPXWWxg1apTqVgAPDw+kpqa+sOXVqVMnuLi4oFWrVnB2dsbly5exfPlydOnSBVZWVgAAX19fAMC0adPQt29fGBsbo
2vXrsXe4Dxp0iT88ssvePfddzF48GD4+voiNTUVO3fuxOrVq9G4cWONr1FgYCAaNGiAxYsXIyQkBO3atcOHH36IsLAwxMTEoFOnTjA2NkZcXBx+/vlnLFu2DL1794azszPGjBmDL7/8Et26dUPnzp1x7tw57NmzBw4ODmq1SK2trbFq1SoMHDgQr732Gvr27QtHR0fcvHkTv/32G1q1aoXly5fj33//RceOHfHee++hXr16MDIywrZt25CSkoK+ffsCAL777jusXLkS77zzDmrWrInMzEx88803sLa2xttvv/3cGNT9HhBphXYna5KuKZy2/bwtMTFRCCHE33//LQICAoSlpaUwNzcXHTp0EH/++WeR+s6ePSvatGkj5HK5cHNzE2FhYeKrr74SAERycrKq3LO3AqxZs0a0bdtWVKlSRcjlclGzZk0xadIkkZ6eLql/7ty5olq1asLAwEByW8CztwIIUXAbwsiRI0W1atWEiYmJcHNzE0FBQeL+/fsvvCaenp6iS5cuxb4XEREhAIj
"text/plain": [
"<Figure size 1200x1000 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"num_models = len(best_models_class)\n",
"fig, ax = plt.subplots(num_models, 1, figsize=(12, 10), sharex=False, sharey=False)\n",
"\n",
"for index, key in enumerate(best_models_class.keys()):\n",
" c_matrix = best_models_class[key][\"Confusion_matrix\"]\n",
" \n",
" # Получаем метки классов из матрицы ошибок\n",
" num_classes = c_matrix.shape[0]\n",
" actual_labels = [\"0-18\", \"19-30\", \"31-50\", \"51-70\", \"71+\"][:num_classes]\n",
" \n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=actual_labels\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA+0AAAQ9CAYAAAAs3qyyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxUVRvA8d+sDOsAyiKIgoi7uaHkkmiSVGaZ5pa9Ki1W7lKW9qZmpVSaWrmgldqiaWaZLZqGaaW4p6XmvqYCirINyzAz9/2D16kRMFAQxOf7+cxH5sy5d57L4Jx55p77HJWiKApCCCGEEEIIIYSodNQVHYAQQgghhBBCCCGKJkm7EEIIIYQQQghRSUnSLoQQQgghhBBCVFKStAshhBBCCCGEEJWUJO1CCCGEEEIIIUQlJUm7EEIIIYQQQghRSUnSLoQQQgghhBBCVFKStAshhBBCCCGEEJWUJO1CCCGEEEIIIUQlJUm7KJGTJ0+iUqlYvHhxRYcibsDgwYMJDg6u6DCqDJVKxSuvvFLRYVS4Tp060alTJ/v9yvh+cXWMQlQ2lfH/jSg9GWcLvPLKK6hUqhL1vRljqYwBJSdjeuUkSbtg8eLFqFQqdu7cWdGhlJsrg8eVm06nIzg4mJEjR5KWllbR4Qlg7ty5qFQqIiIirnsf586d45VXXmHPnj1lF1glt3HjxkJ/23Xq1GHgwIEcP368osMrlS1btvDKK6/I/0lR5cg4m1bR4d22rvztXblptVoCAwMZPHgwZ8+erejwxFVkTBfF0VZ0AOLWULt2bXJyctDpdBUdyg2ZN28ebm5umEwmEhISeO+999i9eze//vprRYd2U7z//vvYbLaKDqNIS5YsITg4mO3bt3P06FHq1q1b6n2cO3eOyZMnExwcTPPmzcs+yEps5MiRtG7dmvz8fHbv3s2CBQv47rvv+OOPPwgICLipsVzv+8WWLVuYPHkygwcPxtPTs3yCE6KSknG2aqis4+yrr75KSEgIubm5bN26lcWLF/Prr7+yb98+DAZDmT/fyy+/zLhx48p8v7cLGdPF1eRMuygRlUqFwWBAo9FUdCjFys7O/tc+jzzyCI899hhPP/00n3/+OX379mXz5s1s3779JkT4N5vNRm5u7k19TgCdToeTk9NNf95/c+LECbZs2cKMGTPw8fFhyZIlFR3SLeeuu+7iscceIyYmhvfee4/p06dz6dIlPvroo2K3MZlM5RLLrfB+IURlcyv8v5Fx9t9V1nH2vvvu47HHHuPJJ5/kgw8+4Pnnn+fYsWOsXr26XJ5Pq9WWy5cBtwsZ08XVJGkXJVLU9SyDBw/Gzc2Ns2fP0qNHD9zc3PDx8eH555/HarU6bG+z2Zg1axaNGzfGYDDg5+fH008/zeXLlx36ff3113Tr1o2AgACcnJwIDQ3ltddeK7S/Tp060aRJE3bt2kXHjh1xcXHhpZdeKvVx3XXXXQAcO3bMoX3btm3ce++9GI1GXFxciIyMZPPmzYW237hxI+Hh4RgMBkJDQ5k/f36R13GpVCqGDx/OkiVLaNy4MU5OTqxduxaAs2fP8vjjj+Pn54eTkxONGzdm4cKFhZ7rvffeo3Hjxri4uODl5UV4eDhLly61P56Zmcno0aMJDg7GyckJX19f7rnnHnbv3m3vU9S1diaTieeee46goCCcnJyoX78+06dPR1GUIo9h1apVNGnSxB7rleP4p4MHD3L69OmifuVFWrJkCV5eXnTr1o1HHnmk2KQ9LS2NMWPG2I+xZs2aDBw4kIsXL7Jx40Zat24NQExMjH1q2ZW/2eDgYAYPHlxon1dfF2U2m5k4cSKtWrXCaDTi6urKXXfdxU8//VTi47kiOTkZrVbL5MmTCz126NAhVCoVs2fPBiA/P5/JkycTFhaGwWCgWrVqdOjQgfXr15f6eQHuvvtuoOALEfh76uqBAwd49NFH8fLyokOHDvb+n376Ka1atcLZ2Rlvb2/69evHmTNnCu13wYIFhIaG4uzsTJs2bfjll18K9Snu+reDBw/Sp08ffHx8cHZ2pn79+vz3v/+1xzd27FgAQkJC7K/fyZMnyyVGISobGWdlnC3PcfZqxb0uBw8e5JFHHsHb2xuDwUB4eHihxL4k41VRr1FeXh5jxozBx8cHd3d3HnzwQf76669CsRVXF6CofS5atIi7774bX19fnJycaNSoEfPmzSvR7+DfXu+ryZi+2KFdxvSbS6bHixtitVqJjo4mIiKC6dOn8+OPP/L2228TGhrKs88+a+/39NNPs3jxYmJiYhg5ciQnTpxg9uzZ/Pbbb2zevNk+5Wbx4sW4ubkRGxuLm5sbGzZsYOLEiWRkZDBt2jSH505NTeW+++6jX79+PPbYY/j5+ZU6/itvHl5eXva2DRs2cN9999GqVSsmTZqEWq22Dwq//PILbdq0AeC3337j3nvvpUaNGkyePBmr1cqrr76Kj49Pkc+1YcMGPv/8c4YPH0716tUJDg4mOTmZO++80z5Q+/j4sGbNGp544gkyMjIYPXo0UDDdbuTIkTzyyCOMGjWK3Nxcfv/9d7Zt28ajjz4KwDPPPMMXX3zB8OHDadSoEampqfz666/8+eeftGzZssiYFEXhwQcf5KeffuKJJ56gefPm/PDDD4wdO5azZ88yc+ZMh/6//vorX375JUOHDsXd3Z13332XXr16cfr0aapVq2bv17BhQyIjI9m4cWOJXoclS5bQs2dP9Ho9/fv3Z968eezYscOehANkZWVx11138eeff/L444/TsmVLLl68yOrVq/nrr79o2LAhr776KhMnTmTIkCH2DyTt2rUrUQxXZGRk8MEHH9C/f3+eeuopMjMz+fDDD4mOjmb79u2lmnbv5+dHZGQkn3/+OZMmTXJ4bPny5Wg0Gnr37g0UDHBxcXE8+eSTtGnThoyMDHbu3Mnu3bu55557SnUM8PcHsX++LgC9e/cmLCyMqVOn2j8wTpkyhQkTJtCnTx+efPJJLly4wHvvvUfHjh357bff7NPaPvzwQ55++mnatWvH6NGjOX78OA8++CDe3t4EBQVdM57ff/+du+66C51Ox5AhQwgODubYsWN88803TJkyhZ49e3L48GE+++wzZs6cSfXq1QHs/59uRoxCVEYyzso4Wxbj7NWKel32799P+/btCQwMZNy4cbi6uvL555/To0cPVq5cycMPPwxc/3j15JNP8umnn/Loo4/Srl07NmzYQLdu3a4r/ivmzZtH48aNefDBB9FqtXzzzTcMHToUm83GsGHDit2uJK/31WRM/5uM6RVAEbe9RYsWKYCyY8eOYvucOHFCAZRFixbZ2wYNGqQAyquvvurQt0WLFkqrVq3s93/55RcFUJYsWeLQb+3atYXas7OzCz33008/rbi4uCi5ubn2tsjISAVQ4uPjS3SMkyZNUgDl0KFDyoULF5STJ08qCxcuVJydnRUfHx/FZDIpiqIoNptNCQsLU6KjoxWbzeYQV0hIiHLPPffY27p37664uLgoZ8+etbcdOXJE0Wq1ytX/tQBFrVYr+/fvd
2h/4oknlBo1aigXL150aO/Xr59iNBrtv4+HHnpIady48TWP0Wg0KsOGDbtmn0GDBim1a9e231+1apUCKK+//rpDv0ceeURRqVTK0aNHHY5Br9c7tO3du1cBlPfee6/Q8UZGRl4zlit27typAMr69esVRSl4DWrWrKmMGjXKod/EiRMVQPnyyy8L7ePKa7Vjx45Cf6dX1K5dWxk0aFCh9sjISIdYLRaLkpeX59Dn8uXLip+fn/L44487tAPKpEmTrnl88+fPVwDljz/+cGhv1KiRcvfdd9vvN2vWTOnWrds191WUn376SQGUhQsXKhcuXFDOnTunfPfdd0pwcLCiUqns/6+v/B/o37+/w/YnT55UNBqNMmXKFIf2P/74Q9FqtfZ2s9ms+Pr6Ks2bN3f4/SxYsKDQ613U+0XHjh0Vd3d35dSpUw7P88//Z9OmTVMA5cSJE+UeoxA3k4yzMs5W1Dh75W/vxx9/VC5cuKCcOXNG+eKLLxQfHx/FyclJOXPmjL1vly5dlKZNmzr8HdhsNqV
"text/plain": [
"<Figure size 1200x1000 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(3, 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"ax = ax.flatten()\n",
"\n",
"for index, (name, model) in enumerate(best_models_reg.items()):\n",
" y_pred_reg = model.predict(X_test_reg)\n",
"\n",
" # График фактических значений против предсказанных значений\n",
" ax[index * 2].scatter(y_test_reg, y_pred_reg, alpha=0.5)\n",
" ax[index * 2].plot([min(y_test_reg), max(y_test_reg)], [min(y_test_reg), max(y_test_reg)], color='red', linestyle='--')\n",
" ax[index * 2].set_xlabel('Actual Values')\n",
" ax[index * 2].set_ylabel('Predicted Values')\n",
" ax[index * 2].set_title(f'{name}: Actual vs Predicted')\n",
"\n",
" # График остатков\n",
" residuals = y_test_reg - y_pred_reg\n",
" ax[index * 2 + 1].scatter(y_pred_reg, residuals, alpha=0.5)\n",
" ax[index * 2 + 1].axhline(y=0, color='red', linestyle='--')\n",
" ax[index * 2 + 1].set_xlabel('Predicted Values')\n",
" ax[index * 2 + 1].set_ylabel('Residuals')\n",
" ax[index * 2 + 1].set_title(f'{name}: Residuals vs Predicted')\n",
"\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}