610 lines
366 KiB
Plaintext
Raw Normal View History

2024-12-06 20:26:56 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вариант: Экономика стран"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 369 entries, 0 to 368\n",
"Data columns (total 14 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 stock index 369 non-null object \n",
" 1 country 369 non-null object \n",
" 2 year 369 non-null float64\n",
" 3 index price 317 non-null float64\n",
" 4 log_indexprice 369 non-null float64\n",
" 5 inflationrate 326 non-null float64\n",
" 6 oil prices 369 non-null float64\n",
" 7 exchange_rate 367 non-null float64\n",
" 8 gdppercent 350 non-null float64\n",
" 9 percapitaincome 368 non-null float64\n",
" 10 unemploymentrate 348 non-null float64\n",
" 11 manufacturingoutput 278 non-null float64\n",
" 12 tradebalance 365 non-null float64\n",
" 13 USTreasury 369 non-null float64\n",
"dtypes: float64(12), object(2)\n",
"memory usage: 40.5+ KB\n"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn import metrics\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.linear_model import LinearRegression, LogisticRegression\n",
"from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier, GradientBoostingClassifier\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.metrics import (\n",
" precision_score, recall_score, accuracy_score, roc_auc_score, f1_score,\n",
" matthews_corrcoef, cohen_kappa_score, confusion_matrix\n",
")\n",
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
"import numpy as np\n",
"import featuretools as ft\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"\n",
"# Random oversampling helper\n",
"def apply_oversampling(X, y):\n",
"    \"\"\"Resample (X, y) with RandomOverSampler(random_state=42) and return the pair.\"\"\"\n",
"    sampler = RandomOverSampler(random_state=42)\n",
"    resampled_X, resampled_y = sampler.fit_resample(X, y)\n",
"    return resampled_X, resampled_y\n",
"\n",
"# Random undersampling helper\n",
"def apply_undersampling(X, y):\n",
"    \"\"\"Resample (X, y) with RandomUnderSampler(random_state=42) and return the pair.\"\"\"\n",
"    sampler = RandomUnderSampler(random_state=42)\n",
"    resampled_X, resampled_y = sampler.fit_resample(X, y)\n",
"    return resampled_X, resampled_y\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input,\n",
"    stratify_colname=\"y\",\n",
"    frac_train=0.6,\n",
"    frac_val=0.15,\n",
"    frac_test=0.25,\n",
"    random_state=None,\n",
"):\n",
"    \"\"\"\n",
"    Splits a Pandas dataframe into three subsets (train, val, and test)\n",
"    following fractional ratios provided by the user, where each subset is\n",
"    stratified by the values in a specific column (that is, each subset has\n",
"    the same relative frequency of the values in the column). It performs this\n",
"    splitting by running train_test_split() twice.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df_input : Pandas dataframe\n",
"        Input dataframe to be split.\n",
"    stratify_colname : str\n",
"        The name of the column that will be used for stratification. Usually\n",
"        this column would be for the label.\n",
"    frac_train : float\n",
"    frac_val : float\n",
"    frac_test : float\n",
"        The ratios with which the dataframe will be split into train, val, and\n",
"        test data. The values should be expressed as float fractions and should\n",
"        sum to 1.0 (checked with a small floating-point tolerance).\n",
"    random_state : int, None, or RandomStateInstance\n",
"        Value to be passed to train_test_split().\n",
"\n",
"    Returns\n",
"    -------\n",
"    df_train, df_val, df_test :\n",
"        Dataframes containing the three splits.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If the fractions do not sum to 1.0, or if stratify_colname is not a\n",
"        column of df_input.\n",
"    \"\"\"\n",
"\n",
"    # Compare with a tolerance: sums such as 0.7 + 0.2 + 0.1 are not exactly\n",
"    # 1.0 in binary floating point, so a strict != check rejects valid input.\n",
"    if abs(frac_train + frac_val + frac_test - 1.0) > 1e-9:\n",
"        raise ValueError(\n",
"            \"fractions %f, %f, %f do not add up to 1.0\"\n",
"            % (frac_train, frac_val, frac_test)\n",
"        )\n",
"\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"\n",
"    X = df_input  # Contains all columns.\n",
"    y = df_input[\n",
"        [stratify_colname]\n",
"    ]  # Dataframe of just the column on which to stratify.\n",
"\n",
"    # Split original dataframe into train and temp dataframes.\n",
"    df_train, df_temp, y_train, y_temp = train_test_split(\n",
"        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
"    )\n",
"\n",
"    # Split the temp dataframe into val and test dataframes; the test share\n",
"    # is expressed relative to the size of the temp frame.\n",
"    relative_frac_test = frac_test / (frac_val + frac_test)\n",
"    df_val, df_test, y_val, y_test = train_test_split(\n",
"        df_temp,\n",
"        y_temp,\n",
"        stratify=y_temp,\n",
"        test_size=relative_frac_test,\n",
"        random_state=random_state,\n",
"    )\n",
"\n",
"    # Sanity check: no rows were lost or duplicated by the two splits.\n",
"    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"\n",
"    return df_train, df_val, df_test\n",
"\n",
"\n",
"df = pd.read_csv(\"../data/Economic.csv\")\n",
"df.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение бизнес-целей\n",
"\n",
"Задача регрессии:\n",
"Предсказать цену индекса (index price) на основе других факторов, таких как инфляция, цены на нефть, обменный курс, ВВП, доход на душу населения, уровень безработицы.\n",
"\n",
"Задача классификации:\n",
"Классификация по странам на основе цены индекса, показателей ВВП, инфляции, безработицы и цены на нефть. Это поможет определить, к какой категории относится страна и какие меры могут быть предприняты для ее развития."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Очистка данных и кодирование признака country"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Исходный размер датасета: 369\n",
"Очищенный размер датасета: 262\n"
]
}
],
"source": [
"# Copy of the raw frame (df itself is not modified in this cell).\n",
"data = df.copy()\n",
"\n",
"# Rows with a missing value in any of these columns are dropped.\n",
"required_columns = [\n",
"    'index price', 'inflationrate', 'exchange_rate',\n",
"    'gdppercent', 'percapitaincome', 'unemploymentrate',\n",
"]\n",
"data_cleaned = df.dropna(subset=required_columns)\n",
"\n",
"print(f\"Исходный размер датасета: {df.shape[0]}\")\n",
"print(f\"Очищенный размер датасета: {data_cleaned.shape[0]}\")\n",
"\n",
"# One-hot encode the country column; drop_first=True omits the first level.\n",
"data1 = pd.get_dummies(data_cleaned, columns=['country'], drop_first=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение достижимого уровня"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['stock index', 'year', 'index price', 'log_indexprice', 'inflationrate',\n",
" 'oil prices', 'exchange_rate', 'gdppercent', 'percapitaincome',\n",
" 'unemploymentrate', 'manufacturingoutput', 'tradebalance', 'USTreasury',\n",
" 'country_France', 'country_Germany', 'country_Hong Kong',\n",
" 'country_India', 'country_Japan', 'country_Spain',\n",
" 'country_United Kingdom', 'country_United States of America'],\n",
" dtype='object')\n",
"stock index 0\n",
"year 0\n",
"index price 0\n",
"log_indexprice 0\n",
"inflationrate 0\n",
"oil prices 0\n",
"exchange_rate 0\n",
"gdppercent 0\n",
"percapitaincome 0\n",
"unemploymentrate 0\n",
"manufacturingoutput 41\n",
"tradebalance 2\n",
"USTreasury 0\n",
"country_France 0\n",
"country_Germany 0\n",
"country_Hong Kong 0\n",
"country_India 0\n",
"country_Japan 0\n",
"country_Spain 0\n",
"country_United Kingdom 0\n",
"country_United States of America 0\n",
"dtype: int64\n"
]
}
],
"source": [
"print(data1.columns)\n",
"print(data1.isnull().sum())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters for Linear Regression: {}\n",
"Best parameters for Random Forest Regressor: {'model__max_depth': 10, 'model__n_estimators': 300}\n",
"Best parameters for Gradient Boosting Regressor: {'model__learning_rate': 0.2, 'model__max_depth': 5, 'model__n_estimators': 300}\n",
"Model: Linear Regression\n",
"Model: Random Forest Regressor\n",
"Model: Gradient Boosting Regressor\n"
]
}
],
"source": [
"# Predictors: everything except the target ('index price'), log_indexprice\n",
"# (presumably the log of the target - TODO confirm) and columns not used as features.\n",
"X_reg = data1.drop(['stock index', 'year', 'index price', 'log_indexprice', 'manufacturingoutput', 'tradebalance', 'USTreasury'], axis=1)\n",
"y_reg = data1['index price']\n",
"\n",
"# Train/test split\n",
"X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(X_reg, y_reg, test_size=0.2, random_state=42)\n",
"\n",
"# Candidate regression models\n",
"models_reg = {\n",
"    'Linear Regression': LinearRegression(),\n",
"    'Random Forest Regressor': RandomForestRegressor(random_state=42),\n",
"    'Gradient Boosting Regressor': GradientBoostingRegressor(random_state=42)\n",
"}\n",
"\n",
"# Wrap each model in a pipeline that standardises the features first\n",
"pipelines_reg = {}\n",
"for name, model in models_reg.items():\n",
"    pipelines_reg[name] = Pipeline([\n",
"        ('scaler', StandardScaler()),\n",
"        ('model', model)\n",
"    ])\n",
"\n",
"# Hyperparameter grids (keys use the pipeline step prefix 'model__')\n",
"param_grids_reg = {\n",
"    'Linear Regression': {},\n",
"    'Random Forest Regressor': {\n",
"        'model__n_estimators': [100, 200, 300],\n",
"        'model__max_depth': [None, 10, 20, 30]\n",
"    },\n",
"    'Gradient Boosting Regressor': {\n",
"        'model__n_estimators': [100, 200, 300],\n",
"        'model__learning_rate': [0.01, 0.1, 0.2],\n",
"        'model__max_depth': [3, 5, 7]\n",
"    }\n",
"}\n",
"\n",
"# Hyperparameter search: 5-fold CV scored by negative MSE\n",
"best_models_reg = {}\n",
"for name, pipeline in pipelines_reg.items():\n",
"    grid_search = GridSearchCV(pipeline, param_grids_reg[name], cv=5, scoring='neg_mean_squared_error')\n",
"    grid_search.fit(X_train_reg, y_train_reg)\n",
"    best_models_reg[name] = {\n",
"        'pipeline': grid_search.best_estimator_,\n",
"        'best_params': grid_search.best_params_\n",
"    }\n",
"    print(f'Best parameters for {name}: {grid_search.best_params_}')\n",
"\n",
"# Evaluate the tuned pipelines on the train and test splits.\n",
"# GridSearchCV (refit=True by default) has already refit best_estimator_ on\n",
"# the whole training split, so it is used directly without a second fit.\n",
"for model_name in best_models_reg.keys():\n",
"    print(f\"Model: {model_name}\")\n",
"    model_pipeline = best_models_reg[model_name]['pipeline']\n",
"\n",
"    y_train_predict = model_pipeline.predict(X_train_reg)\n",
"    y_test_predict = model_pipeline.predict(X_test_reg)\n",
"\n",
"    best_models_reg[model_name][\"preds_train\"] = y_train_predict\n",
"    best_models_reg[model_name][\"preds_test\"] = y_test_predict\n",
"\n",
"    best_models_reg[model_name][\"MSE_train\"] = mean_squared_error(y_train_reg, y_train_predict)\n",
"    best_models_reg[model_name][\"MSE_test\"] = mean_squared_error(y_test_reg, y_test_predict)\n",
"    best_models_reg[model_name][\"R2_train\"] = r2_score(y_train_reg, y_train_predict)\n",
"    best_models_reg[model_name][\"R2_test\"] = r2_score(y_test_reg, y_test_predict)\n",
"    best_models_reg[model_name][\"MAE_train\"] = mean_absolute_error(y_train_reg, y_train_predict)\n",
"    best_models_reg[model_name][\"MAE_test\"] = mean_absolute_error(y_test_reg, y_test_predict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['United States of America' 'United Kingdom' 'India' 'Japan' 'Hong Kong'\n",
" 'China' 'Germany' 'France' 'Spain']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 262 entries, 0 to 367\n",
"Data columns (total 6 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 country 262 non-null object \n",
" 1 index price 262 non-null float64\n",
" 2 inflationrate 262 non-null float64\n",
" 3 oil prices 262 non-null float64\n",
" 4 gdppercent 262 non-null float64\n",
" 5 unemploymentrate 262 non-null float64\n",
"dtypes: float64(5), object(1)\n",
"memory usage: 14.3+ KB\n",
"gdppercent\n",
" 0.02 55\n",
" 0.03 39\n",
" 0.04 37\n",
" 0.01 30\n",
" 0.05 20\n",
" 0.08 19\n",
" 0.07 14\n",
" 0.06 12\n",
"-0.01 10\n",
" 0.09 6\n",
" 0.10 4\n",
" 0.11 3\n",
"-0.02 2\n",
"-0.03 2\n",
"-0.05 2\n",
" 0.12 1\n",
"-0.04 1\n",
"-0.10 1\n",
"-0.06 1\n",
"-0.08 1\n",
" 0.13 1\n",
" 0.14 1\n",
"Name: count, dtype: int64\n"
]
}
],
"source": [
"data2 = data_cleaned.drop(['stock index', 'year', 'log_indexprice', 'exchange_rate', 'percapitaincome', 'manufacturingoutput', 'tradebalance', 'USTreasury'], axis=1)\n",
"print(data2[\"country\"].unique())\n",
"data2.info()\n",
"print(data2[\"gdppercent\"].value_counts())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['index price', 'inflationrate', 'oil prices', 'gdppercent',\n",
" 'unemploymentrate'],\n",
" dtype='object')\n",
"Best parameters for Logistic Regression: {'model__C': 10, 'model__solver': 'lbfgs'}\n",
"Best parameters for Random Forest Classifier: {'model__max_depth': None, 'model__n_estimators': 200}\n",
"Best parameters for Gradient Boosting Classifier: {'model__learning_rate': 0.1, 'model__max_depth': 3, 'model__n_estimators': 300}\n",
"Model: Logistic Regression\n",
"Model: Random Forest Classifier\n",
"Model: Gradient Boosting Classifier\n"
]
}
],
"source": [
"# Features and target for the classification task\n",
"X_class = data2.drop(['country'], axis=1)\n",
"y_class = data2['country']\n",
"print(X_class.columns)\n",
"# Train/test split\n",
"X_train_class, X_test_class, y_train_class, y_test_class = train_test_split(X_class, y_class, test_size=0.2, random_state=20)\n",
"\n",
"# Candidate classifiers\n",
"models_class = {\n",
"    'Logistic Regression': LogisticRegression(random_state=42, max_iter=5000, solver='liblinear'),\n",
"    'Random Forest Classifier': RandomForestClassifier(random_state=42),\n",
"    'Gradient Boosting Classifier': GradientBoostingClassifier(random_state=42)\n",
"}\n",
"\n",
"# Wrap each classifier in a pipeline that standardises the features first\n",
"pipelines_class = {}\n",
"for name, model in models_class.items():\n",
"    pipelines_class[name] = Pipeline([\n",
"        ('scaler', StandardScaler()),\n",
"        ('model', model)\n",
"    ])\n",
"\n",
"# Hyperparameter grids (keys use the pipeline step prefix 'model__')\n",
"param_grids_class = {\n",
"    'Logistic Regression': {\n",
"        'model__C': [0.1, 1, 10],\n",
"        'model__solver': ['lbfgs', 'liblinear']\n",
"    },\n",
"    'Random Forest Classifier': {\n",
"        'model__n_estimators': [100, 200, 300],\n",
"        'model__max_depth': [None, 10, 20, 30]\n",
"    },\n",
"    'Gradient Boosting Classifier': {\n",
"        'model__n_estimators': [100, 200, 300],\n",
"        'model__learning_rate': [0.01, 0.1, 0.2],\n",
"        'model__max_depth': [3, 5, 7]\n",
"    }\n",
"}\n",
"\n",
"# Hyperparameter search: 5-fold CV scored by accuracy, run in parallel\n",
"best_models_class = {}\n",
"for name, pipeline in pipelines_class.items():\n",
"    grid_search = GridSearchCV(pipeline, param_grids_class[name], cv=5, scoring='accuracy', n_jobs=-1)\n",
"    grid_search.fit(X_train_class, y_train_class)\n",
"    best_models_class[name] = {\n",
"        'pipeline': grid_search.best_estimator_,\n",
"        'best_params': grid_search.best_params_\n",
"    }\n",
"    print(f'Best parameters for {name}: {grid_search.best_params_}')\n",
"\n",
"# Evaluate the tuned pipelines on the train and test splits.\n",
"# GridSearchCV (refit=True by default) has already refit best_estimator_ on\n",
"# the whole training split, so it is used directly without a second fit.\n",
"for model_name in best_models_class.keys():\n",
"    print(f\"Model: {model_name}\")\n",
"\n",
"    model_pipeline = best_models_class[model_name]['pipeline']\n",
"\n",
"    y_train_predict = model_pipeline.predict(X_train_class)\n",
"    y_test_probs = model_pipeline.predict_proba(X_test_class)\n",
"    y_test_predict = model_pipeline.predict(X_test_class)\n",
"\n",
"    best_models_class[model_name]['preds_train'] = y_train_predict\n",
"    best_models_class[model_name][\"probs\"] = y_test_probs\n",
"    best_models_class[model_name][\"preds\"] = y_test_predict\n",
"\n",
"    # Multiclass target: precision/recall/F1 are averaged weighted by class support.\n",
"    best_models_class[model_name][\"Precision_train\"] = precision_score(y_train_class, y_train_predict, average='weighted')\n",
"    best_models_class[model_name][\"Precision_test\"] = precision_score(y_test_class, y_test_predict, average='weighted')\n",
"    best_models_class[model_name][\"Recall_train\"] = recall_score(y_train_class, y_train_predict, average='weighted')\n",
"    best_models_class[model_name][\"Recall_test\"] = recall_score(y_test_class, y_test_predict, average='weighted')\n",
"    best_models_class[model_name][\"Accuracy_train\"] = accuracy_score(y_train_class, y_train_predict)\n",
"    best_models_class[model_name][\"Accuracy_test\"] = accuracy_score(y_test_class, y_test_predict)\n",
"    best_models_class[model_name][\"ROC_AUC_test\"] = roc_auc_score(y_test_class, y_test_probs, multi_class='ovr', average='weighted')\n",
"    best_models_class[model_name][\"F1_train\"] = f1_score(y_train_class, y_train_predict, average='weighted')\n",
"    best_models_class[model_name][\"F1_test\"] = f1_score(y_test_class, y_test_predict, average='weighted')\n",
"    best_models_class[model_name][\"MCC_test\"] = matthews_corrcoef(y_test_class, y_test_predict)\n",
"    best_models_class[model_name][\"Cohen_kappa_test\"] = cohen_kappa_score(y_test_class, y_test_predict)\n",
"    best_models_class[model_name][\"Confusion_matrix\"] = confusion_matrix(y_test_class, y_test_predict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjQAAAQ9CAYAAABOVNfxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1gUVxcH4N/Awi6wdJAiSJEiSFGxY8GowUbs3SgaS+waK1EUxRJ7id0Y0ERFY9SgiRU7lliAmICIBSGKYqGISN37/eHHxBWQXWR3Ked9nnl0ZmfuuXd2djl7584MxxhjIIQQQgipwtRUXQFCCCGEkE9FCQ0hhBBCqjxKaAghhBBS5VFCQwghhJAqjxIaQgghhFR5lNAQQgghpMqjhIYQQgghVR4lNIQQQgip8iihIYQQQkiVRwkNIaTS8vHxgY+PT4WVZ2trC39//worjwAcxyEoKEjV1SCEEhpCSNlCQ0PBcRxu3Lih6qqU6fLlywgKCkJ6erpC49ja2oLjOH7S0dFB06ZNsWvXLoXGJYSUTKDqChBCSGlOnjwp9zaXL1/GggUL4O/vDwMDA6nX4uPjoaZWcb/jGjRogGnTpgEAUlJS8MMPP2DYsGHIzc3FqFGjKixOZfb27VsIBPSnhKgeHYWEkEpLU1OzQssTCoUVWl7t2rUxZMgQft7f3x/29vZYs2aN0hOaN2/eQEdHR6kxAUAkEik9JiEloVNOhJAKExUVhc6dO0NPTw9isRjt27fH1atXi633119/oW3bttDS0oKVlRUWLVqEkJAQcByHxMREfr2SxtB8//33qF+/PrS1tWFoaIjGjRtjz549AICgoCDMmDEDAGBnZ8efDioqs6QxNOnp6Zg6dSpsbW0hFAphZWWFoUOH4sWLF3K339TUFPXq1cP9+/ellkskEqxduxb169eHSCSCmZkZxowZg7S0tGLrBQUFwdLSEtra2mjXrh1iY2OL1bvoFOD58+cxbtw41KpVC1ZWVvzrx44dQ+vWraGjowNdXV107doV//zzj1Ssp0+fYvjw4bCysoJQKISFhQW6d+8utf9v3LgBX19fmJiYQEtLC3Z2dhgxYoRUOSWNoZHlOChqQ2RkJL755huYmppCR0cHPXv2xPPnz2Xd5YTwqIeGEFIh/vnnH7Ru3Rp6enqYOXMmNDQ0sHXrVvj4+OD8+fNo1qwZAODx48do164dOI5DQEAAdHR08MMPP8jUe7J9+3ZMmjQJffr0weTJk5GTk4O//voL165dw6BBg9CrVy/cvXsXe/fuxZo1a2BiYgLgXaJRkqysLLRu3RpxcXEYMWIEGjVqhBcvXiA8PBz//vsvv72sCgoK8O+//8LQ0FBq+ZgxYxAaGorhw4dj0qRJePjwITZs2ICoqChERkZCQ0MDABAQEIDly5fDz88Pvr6+iImJga+vL3JyckqMN27cOJiammLevHl48+YNAOCnn37CsGHD4Ovri2XLliE7OxubN29Gq1atEBUVBVtbWwBA79698c8//2DixImwtbVFamoqTp06haSkJH7+888/h6mpKWbPng0DAwMkJibi4MGDH90Hsh4HRSZOnAhDQ0PMnz8fiYmJWLt2LSZMmIB9+/bJte8JASOEkDKEhIQwAOz69eulrtOjRw+mqanJ7t+/zy978uQJ09XVZW3atOGXTZw4kXEcx6KiovhlL1++ZEZGRgwAe/jwIb+8bdu2rG3btvx89+7dWf369T9a1xUrVhQrp4iNjQ0bNmwYPz9v3jwGgB08eLDYuhKJ5KNxbGxs2Oeff86eP3/Onj9/zm7fvs2+/PJLBoCNHz+eX+/ixYsMANu9e7fU9sePH5da/vTpUyYQCFiPHj2k1gsKCmIApOpd9H60atWKFRQU8Mtfv37NDAwM2KhRo6TKePr0KdPX1+eXp6WlMQBsxYoVpbbv0KFDZb7njDEGgM2fP5+fl/U4KGpDhw4dpPb11KlTmbq6OktPT/9oXE
I+RKecCCGfrLCwECdPnkSPHj1gb2/PL7ewsMCgQYNw6dIlZGZmAgCOHz+OFi1aoEGDBvx6RkZGGDx4cJlxDAwM8O+//+L69esVUu9ff/0Vnp6e6NmzZ7HXOI4rc/uTJ0/C1NQUpqamcHd3x08//YThw4djxYoV/Dq//PIL9PX10bFjR7x48YKfvLy8IBaLcfbsWQBAREQECgoKMG7cOKkYEydOLDX+qFGjoK6uzs+fOnUK6enpGDhwoFQsdXV1NGvWjI+lpaUFTU1NnDt3rthpryJFA6qPHj2K/Pz8MvcFIN9xUGT06NFS+7p169YoLCzEo0ePZIpJSBFKaAghn+z58+fIzs6Gs7NzsddcXFwgkUiQnJwMAHj06BEcHByKrVfSsg/NmjULYrEYTZs2haOjI8aPH4/IyMhy1/v+/ftwc3Mr9/bNmjXDqVOncPz4caxcuRIGBgZIS0uTGsyckJCAjIwM1KpVi09+iqasrCykpqYCAP8H/MP9YGRkVOwUVhE7Ozup+YSEBADAZ599VizWyZMn+VhCoRDLli3DsWPHYGZmhjZt2mD58uV4+vQpX1bbtm3Ru3dvLFiwACYmJujevTtCQkKQm5tb6v6Q5zgoUqdOHan5oraWlmgRUhoaQ0MIqTJcXFwQHx+Po0eP4vjx4/j111+xadMmzJs3DwsWLFB6fUxMTNChQwcAgK+vL+rVq4du3bph3bp1+OabbwC8G+hbq1Yt7N69u8QyShvfIwstLS2peYlEAuDdOBpzc/Ni679/efWUKVPg5+eHw4cP48SJEwgMDMTSpUtx5swZNGzYEBzH4cCBA7h69SqOHDmCEydOYMSIEVi1ahWuXr0KsVhc7nq/7/0epvcxxiqkfFJzUEJDCPlkpqam0NbWRnx8fLHX7ty5AzU1NVhbWwMAbGxscO/evWLrlbSsJDo6Oujfvz/69++PvLw89OrVC4sXL0ZAQABEIpFMp4qK1K1bF3///bfM65ela9euaNu2LZYsWYIxY8ZAR0cHdevWxenTp+Ht7V0sAXmfjY0NgHf74f2el5cvX8rcW1G3bl0AQK1atfhEq6z1p02bhmnTpiEhIQENGjTAqlWr8PPPP/PrNG/eHM2bN8fixYuxZ88eDB48GGFhYRg5cmSx8uQ5DgipaHTKiRDyydTV1fH555/jt99+k7rs99mzZ9izZw9atWoFPT09AO96Mq5cuYLo6Gh+vVevXpXag/G+ly9fSs1ramrC1dUVjDF+nEfRvVhkuVNw7969ERMTg0OHDhV7rbw9BLNmzcLLly+xfft2AEC/fv1QWFiI4ODgYusWFBTw9Wzfvj0EAgE2b94stc6GDRtkju3r6ws9PT0sWbKkxHEvRZdDZ2dnF7tyqm7dutDV1eVPKaWlpRXbB0Xjnko77STPcUBIRaMeGkKIzH788UccP3682PLJkydj0aJFOHXqFFq1aoVx48ZBIBBg69atyM3NxfLly/l1Z86ciZ9//hkdO3bExIkT+cu269Spg1evXn20h+Xzzz+Hubk5vL29YWZmhri4OGzYsAFdu3aFrq4uAMDLywsAMGfOHAwYMAAaGhrw8/Mr8aZzM2bMwIEDB9C3b1+MGDECXl5eePXqFcLDw7FlyxZ4enrKvY86d+4MNzc3rF69GuPHj0fbtm0xZswYLF26FNHR0fj888+hoaGBhIQE/PLLL1i3bh369OkDMzMzTJ48GatWrcIXX3yBTp06ISYmBseOHYOJiYlMPU96enrYvHkzvvzySzRq1AgDBgyAqakpkpKS8Pvvv8Pb2xsbNmzA3bt30b59e/Tr1w+urq4QCAQ4dOgQnj17hgEDBgAAdu7ciU2bNqFnz56oW7cuXr9+je3bt0NPTw9dunQptQ6yHgeEVDjVXmRFCKkKii6xLW1KTk5mjDF269Yt5uvry8RiMdPW1mbt2rVjly9fLlZeVFQUa926NRMKhczKyootXbqUrV+/ngFgT58+5df78LLtrVu3sjZt2j
BjY2MmFApZ3bp12YwZM1hGRoZU+cHBwax27dpMTU1N6hLuDy/bZuzdJeMTJkxgtWvXZpqamszKyooNGzaMvXjx4qP
"text/plain": [
"<Figure size 1200x1000 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# One confusion matrix per tuned classifier, stacked vertically.\n",
"num_models = len(best_models_class)\n",
"# squeeze=False keeps ax a 2-D array even when there is a single model.\n",
"fig, ax = plt.subplots(num_models, 1, figsize=(12, 10), sharex=False, sharey=False, squeeze=False)\n",
"for index, key in enumerate(best_models_class.keys()):\n",
"    c_matrix = best_models_class[key][\"Confusion_matrix\"]\n",
"    # confusion_matrix() orders rows/columns by the sorted labels that occur in\n",
"    # y_true or y_pred, so the tick labels are derived the same way instead of\n",
"    # being hard-coded in a different (first-appearance) order.\n",
"    class_labels = sorted(set(y_test_class) | set(best_models_class[key][\"preds\"]))\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=class_labels\n",
"    ).plot(ax=ax.flat[index])\n",
"    disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABBUAAAQ9CAYAAAAGfnVcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3xT5f7A8U+SZnSmFDpoqUALyAYFRUSGgpSh/hAQwcVygwooIFdEhoiC4GI5LuNe8Sooer2CyFaRqsgQAcGWIUjpgNJ0N21yfn/EBtJFU9Imab/v16svyMnTk+9J0jznfPM830elKIqCEEIIIYQQQgghhJPU7g5ACCGEEEIIIYQQ3kmSCkIIIYQQQgghhKgSSSoIIYQQQgghhBCiSiSpIIQQQgghhBBCiCqRpIIQQgghhBBCCCGqRJIKQgghhBBCCCGEqBJJKgghhBBCCCGEEKJKJKkghBBCCCGEEEKIKpGkghBCCCGEEEIIIapEkgqixpw6dQqVSsWqVavcHYq4CqNGjaJJkybuDqPWUKlUzJw5091huF2vXr3o1auX/bYnfl6UjFEIT+OJfzfCedLP2sycOROVSlWptjXRl0ofUHnSp9c9klQQLrFq1SpUKhW//PKLu0OpNsWdW/GPVqulSZMmPP3002RkZLg7PAEsXboUlUpFly5dqryPpKQkZs6cyYEDB1wXmIfbuXNnqfd2TEwMDz30ECdOnHB3eE7ZvXs3M2fOlL9JUetIP5vh7vDqrOL3XvGPj48PUVFRjBo1irNnz7o7PFGC9OnCHXzcHYCoOxo3bkxeXh5ardbdoVyVZcuWERAQQE5ODtu2beOdd95h37597Nq1y92h1Yj3338fq9Xq7jDKtGbNGpo0acLPP/9MYmIizZo1c3ofSUlJzJo1iyZNmtCxY0fXB+nBnn76aW644QYKCwvZt28f7733Hhs2bOC3334jMjKyRmOp6ufF7t27mTVrFqNGjSI4OLh6ghPCQ0k/Wzt4aj87e/ZsmjZtSn5+Pj/++COrVq1i165dHDp0CIPB4PLHmz59Os8//7zL91tXSJ8uapKMVBA1RqVSYTAY0Gg07g6lXLm5uVdsM3ToUB544AEee+wx1q5dy7333ssPP/zAzz//XAMRXmK1WsnPz6/RxwTQarXo9foaf9wrOXnyJLt372bRokWEhoayZs0ad4fkdbp3784DDzzA6NGjeeedd3j99ddJT09n9erV5f5OTk5OtcTiDZ8XQngab/i7kX72yjy1n+3fvz8PPPAADz/8MB988AHPPfccx48f58svv6yWx/Px8amWZEVdIX26qEmSVBA1pqz5VKNGjSIgIICzZ88yaNAgAgICCA0N5bnnnsNisTj8vtVq5c0336RNmzYYDAbCw8N57LHHuHjxokO7//73vwwcOJDIyEj0ej2xsbHMmTOn1P569epF27Zt2bt3Lz169MDPz49//OMfTh9X9+7dATh+/LjD9p9++ol+/fphNBrx8/OjZ8+e/PDDD6V+f+fOnXTu3BmDwUBsbCzvvvtumfMIVSoV48ePZ82aNbRp0wa9Xs+mTZsAOHv2LGPGjCE8PBy9Xk+bNm1YsWJFqcd65513aNOmDX5+ftSrV4/OnTvz0Ucf2e/PyspiwoQJNGnSBL1eT1hYGLfffjv79u2ztylrrmdOTg7PPvss0dHR6PV6rr32Wl5//XUURSnzGL744gvatm1rj7X4OC539OhRTp8+XdZTXqY1a9ZQr149Bg4cyNChQ8tNKmRkZDBx4kT7MTZq1IiHHnqI8+fPs3PnTm644QYARo8ebR86WPyebdKkCaNGjSq1z5Lz8sxmMzNmzKBTp04YjUb8/f3p3r07O3bsqPTxFEtJScHHx4dZs2aVuu/YsWOoVCoWL14MQGFhIbNmzaJ58+YYDAbq16/PLbfcwpYtW5x+XIDbbrsNsCVs4NLQ5CNHjnDfffdRr149brnlFnv7Dz/8kE6dOu
Hr60tISAjDhw/nzJkzpfb73nvvERsbi6+vLzfeeCPff/99qTblzb88evQow4YNIzQ0FF9fX6699lpeeOEFe3yTJ08GoGnTpvbX79SpU9USoxCeRvpZ6Wers58tqbzX5ejRowwdOpSQkBAMBgOdO3culXioTH9V1mtUUFDAxIkTCQ0NJTAwkLvuuou//vqrVGzl1aUoa58rV67ktttuIywsDL1eT+vWrVm2bFmlnoMrvd4lSZ++ymG79Om1h0x/EG5nsViIi4ujS5cuvP7662zdupWFCxcSGxvLE088YW/32GOPsWrVKkaPHs3TTz/NyZMnWbx4Mfv37+eHH36wD6latWoVAQEBTJo0iYCAALZv386MGTPIzMxkwYIFDo994cIF+vfvz/Dhw3nggQcIDw93Ov7iD7d69erZt23fvp3+/fvTqVMnXnrpJdRqtb3T+v7777nxxhsB2L9/P/369aNhw4bMmjULi8XC7NmzCQ0NLfOxtm/fztq1axk/fjwNGjSgSZMmpKSkcNNNN9lPJEJDQ/n6668ZO3YsmZmZTJgwAbANp3z66acZOnQozzzzDPn5+Rw8eJCffvqJ++67D4DHH3+cTz/9lPHjx9O6dWsuXLjArl27+P3337n++uvLjElRFO666y527NjB2LFj6dixI9988w2TJ0/m7NmzvPHGGw7td+3axfr163nyyScJDAzk7bffZsiQIZw+fZr69evb27Vq1YqePXuyc+fOSr0Oa9asYfDgweh0OkaMGMGyZcvYs2ePPUkAkJ2dTffu3fn9998ZM2YM119/PefPn+fLL7/kr7/+olWrVsyePZsZM2bw6KOP2k+Ybr755krFUCwzM5MPPviAESNG8Mgjj5CVlcU///lP4uLi+Pnnn52aVhEeHk7Pnj1Zu3YtL730ksN9n3zyCRqNhnvuuQewdcDz5s3j4Ycf5sYbbyQzM5NffvmFffv2cfvttzt1DHDpRPHy1wXgnnvuoXnz5rzyyiv2E9q5c+fy4osvMmzYMB5++GHS0tJ455136NGjB/v377cPW/znP//JY489xs0338yECRM4ceIEd911FyEhIURHR1cYz8GDB+nevTtarZZHH32UJk2acPz4cf73v/8xd+5cBg8ezB9//MF//vMf3njjDRo0aABg/3uqiRiF8ETSz0o/64p+tqSyXpfDhw/TrVs3oqKieP755/H392ft2rUMGjSIzz77jLvvvhuoen/18MMP8+GHH3Lfffdx8803s337dgYOHFil+IstW7aMNm3acNddd+Hj48P//vc/nnzySaxWK+PGjSv39yrzepckffol0qfXMooQLrBy5UoFUPbs2VNum5MnTyqAsnLlSvu2kSNHKoAye/Zsh7bXXXed0qlTJ/vt77//XgGUNWvWOLTbtGlTqe25ubmlHvuxxx5T/Pz8lPz8fPu2nj17KoCyfPnySh3jSy+9pADKsWPHlLS0NOXUqVPKihUrFF9fXyU0NFTJyclRFEVRrFar0rx5cyUuLk6xWq0OcTVt2lS5/fbb7dvuvPNOxc/PTzl79qx9W0JCguLj46OU/PMEFLVarRw+fNhh+9ixY5WGDRsq58+fd9g+fPhwxWg02p+P//u//1PatGlT4TEajUZl3LhxFbYZOXKk0rhxY/vtL774QgGUl19+2aHd0KFDFZVKpSQmJjocg06nc9j266+/KoDyzjvvlDrenj17VhhLsV9++UUBlC1btiiKYnsNGjVqpDzzzDMO7WbMmKEAyvr160vto/i12rNnT6n3abHGjRsrI0eOLLW9Z8+eDrEWFRUpBQUFDm0uXryohIeHK2PGjHHYDigvvfRShcf37rvvKoDy22+/OWxv3bq1ctttt9lvd+jQQRk4cGCF+yrLjh07FEBZsWKFkpaWpiQlJSkbNmxQmjRpoqhUKvvfdfHfwIgRIxx+/9SpU4pGo1Hmzp3rsP23335TfHx87NvNZrMSFhamdOzY0eH5ee+990q93mV9Xv
To0UMJDAxU/vzzT4fHufzvbMGCBQqgnDx5stpjFKImST8r/ay7+tni997WrVuVtLQ05cyZM8qnn36qhIaGKnq9Xjl
"text/plain": [
"<Figure size 1200x1000 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# One row per model: actual-vs-predicted on the left, residuals on the right.\n",
"# The row count follows the number of tuned models instead of a hard-coded 3.\n",
"_, ax = plt.subplots(len(best_models_reg), 2, figsize=(12, 10), sharex=False, sharey=False, squeeze=False)\n",
"ax = ax.flatten()\n",
"\n",
"# End points of the y = x reference line, computed once for all models.\n",
"y_min, y_max = min(y_test_reg), max(y_test_reg)\n",
"\n",
"for index, (name, model) in enumerate(best_models_reg.items()):\n",
"    model_pipeline = model['pipeline']\n",
"    y_pred_reg = model_pipeline.predict(X_test_reg)\n",
"\n",
"    # Actual vs. predicted values; points on the dashed line are perfect predictions\n",
"    ax[index * 2].scatter(y_test_reg, y_pred_reg, alpha=0.5)\n",
"    ax[index * 2].plot([y_min, y_max], [y_min, y_max], color='red', linestyle='--')\n",
"    ax[index * 2].set_xlabel('Actual Values')\n",
"    ax[index * 2].set_ylabel('Predicted Values')\n",
"    ax[index * 2].set_title(f'{name}: Actual vs Predicted')\n",
"\n",
"    # Residual plot (actual minus predicted)\n",
"    residuals = y_test_reg - y_pred_reg\n",
"    ax[index * 2 + 1].scatter(y_pred_reg, residuals, alpha=0.5)\n",
"    ax[index * 2 + 1].axhline(y=0, color='red', linestyle='--')\n",
"    ax[index * 2 + 1].set_xlabel('Predicted Values')\n",
"    ax[index * 2 + 1].set_ylabel('Residuals')\n",
"    ax[index * 2 + 1].set_title(f'{name}: Residuals vs Predicted')\n",
"\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimvenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}