From a1830c3724796647a1cdf547109b91e6b59fe520 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9D=D0=B8=D0=BA=D0=B8=D1=82=D0=B0=20=D0=9F=D0=BE=D1=82?= =?UTF-8?q?=D0=B0=D0=BF=D0=BE=D0=B2?= Date: Sun, 1 Dec 2024 11:09:12 +0400 Subject: [PATCH] =?UTF-8?q?=D0=97=D0=B0=D0=BA=D0=BE=D0=BD=D1=87=D0=B8?= =?UTF-8?q?=D0=BB=20=D1=80=D0=B0=D0=B1=D0=BE=D1=82=D1=83=20=D1=81=20Featur?= =?UTF-8?q?eTools?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Произвел нормализацию; Использовал категориальное и One Hot кодирование; Разбил на выборки и аугментировал; Обучил модель и исследовал результаты обучения; --- lab_3/lab3.ipynb | 372 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 340 insertions(+), 32 deletions(-) diff --git a/lab_3/lab3.ipynb b/lab_3/lab3.ipynb index 6411747..b3e06d3 100644 --- a/lab_3/lab3.ipynb +++ b/lab_3/lab3.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 228, + "execution_count": 248, "metadata": {}, "outputs": [], "source": [ @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 229, + "execution_count": 249, "metadata": {}, "outputs": [], "source": [ @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 230, + "execution_count": 250, "metadata": {}, "outputs": [ { @@ -225,7 +225,7 @@ "BMI 31.89 97.65 " ] }, - "execution_count": 230, + "execution_count": 250, "metadata": {}, "output_type": "execute_result" } @@ -244,7 +244,7 @@ }, { "cell_type": "code", - "execution_count": 231, + "execution_count": 251, "metadata": {}, "outputs": [], "source": [ @@ -266,7 +266,7 @@ }, { "cell_type": "code", - "execution_count": 232, + "execution_count": 252, "metadata": {}, "outputs": [ { @@ -584,7 +584,7 @@ "39 CovidPos False 0.0" ] }, - "execution_count": 232, + "execution_count": 252, "metadata": {}, "output_type": "execute_result" } @@ -602,7 +602,7 @@ }, { "cell_type": "code", - "execution_count": 233, + "execution_count": 253, "metadata": {}, "outputs": [], "source": [ @@ 
-615,7 +615,7 @@ }, { "cell_type": "code", - "execution_count": 234, + "execution_count": 254, "metadata": {}, "outputs": [], "source": [ @@ -660,7 +660,7 @@ }, { "cell_type": "code", - "execution_count": 235, + "execution_count": 255, "metadata": {}, "outputs": [ { @@ -776,7 +776,7 @@ "5 24.27 31.89 " ] }, - "execution_count": 235, + "execution_count": 255, "metadata": {}, "output_type": "execute_result" } @@ -788,7 +788,7 @@ }, { "cell_type": "code", - "execution_count": 236, + "execution_count": 256, "metadata": {}, "outputs": [], "source": [ @@ -818,7 +818,7 @@ }, { "cell_type": "code", - "execution_count": 237, + "execution_count": 257, "metadata": {}, "outputs": [ { @@ -838,7 +838,7 @@ }, { "cell_type": "code", - "execution_count": 238, + "execution_count": 258, "metadata": {}, "outputs": [], "source": [ @@ -869,7 +869,7 @@ }, { "cell_type": "code", - "execution_count": 239, + "execution_count": 259, "metadata": {}, "outputs": [], "source": [ @@ -886,7 +886,7 @@ }, { "cell_type": "code", - "execution_count": 240, + "execution_count": 260, "metadata": {}, "outputs": [ { @@ -1002,7 +1002,7 @@ "5 24.27 31.89 " ] }, - "execution_count": 240, + "execution_count": 260, "metadata": {}, "output_type": "execute_result" } @@ -1020,7 +1020,7 @@ }, { "cell_type": "code", - "execution_count": 241, + "execution_count": 261, "metadata": {}, "outputs": [ { @@ -1047,7 +1047,7 @@ }, { "cell_type": "code", - "execution_count": 242, + "execution_count": 262, "metadata": {}, "outputs": [], "source": [ @@ -1056,7 +1056,7 @@ }, { "cell_type": "code", - "execution_count": 247, + "execution_count": 263, "metadata": {}, "outputs": [ { @@ -1179,7 +1179,7 @@ "BMINorm 0.625000 1.0 " ] }, - "execution_count": 247, + "execution_count": 263, "metadata": {}, "output_type": "execute_result" } @@ -1219,7 +1219,7 @@ }, { "cell_type": "code", - "execution_count": 221, + "execution_count": 264, "metadata": {}, "outputs": [], "source": [ @@ -1228,7 +1228,7 @@ }, { "cell_type": "code", - 
"execution_count": 222, + "execution_count": 266, "metadata": {}, "outputs": [ { @@ -1380,7 +1380,7 @@ "# Преобразуем датасет с помощью фремйворка\n", "# https://featuretools.alteryx.com/en/stable/getting_started/afe.html\n", "\n", - "entity_set = ft.EntitySet().add_dataframe(df, \"df\", make_index=True, index=\"id\")\n", + "entity_set = ft.EntitySet().add_dataframe(df_norm, \"df\", make_index=True, index=\"id\")\n", "\n", "feature_matrix, feature_defs = ft.dfs(\n", " entityset=entity_set,\n", @@ -1401,7 +1401,7 @@ }, { "cell_type": "code", - "execution_count": 224, + "execution_count": 267, "metadata": {}, "outputs": [ { @@ -1430,15 +1430,12 @@ "\n", "\n", "\n", - "\n", - "\n", "\n", "\n", "\n", "\n", "\n", "\n", - "\n", "\n", "\n", "\n", @@ -1492,9 +1489,6 @@ "\n", "\n", "\n", - "\n", - "\n", - "\n", "\n", "\n", "\n", @@ -1508,7 +1502,13 @@ "\n", "\n", "\n", - "\n" + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" ] } ], @@ -1523,6 +1523,314 @@ "print(\"Стало признаков:\", len(features_enc))\n", "print(*features_enc, sep='\\n')" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Разобьем данные на выборки" + ] + }, + { + "cell_type": "code", + "execution_count": 277, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Размеры выборок:\n", + "Обучающая выборка: (196817, 98)\n", + "Тестовая выборка: (24602, 98)\n", + "Контрольная выборка: (24603, 98)\n" + ] + } + ], + "source": [ + "prepared_dataset = feature_matrix_enc\n", + "\n", + "target_column = \"HadHeartAttack\"\n", + "\n", + "X = prepared_dataset.drop(columns=[target_column]) \n", + "Y = prepared_dataset[target_column] \n", + "\n", + "# Обучающая выборка\n", + "X_train, X_temp, Y_train, Y_temp = train_test_split(X, Y, test_size=0.2, random_state=None, 
stratify=Y)\n", + "\n", + "# Тестовая и контрольная выборки\n", + "X_test, X_control, Y_test, Y_control = train_test_split(X_temp, Y_temp, test_size=0.5, random_state=None, stratify=Y_temp)\n", + "\n", + "print(\"Размеры выборок:\")\n", + "print(f\"Обучающая выборка: {X_train.shape}\")\n", + "print(f\"Тестовая выборка: {X_test.shape}\")\n", + "print(f\"Контрольная выборка: {X_control.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 317, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HadHeartAttack\n", + "False 232587\n", + "True 13435\n", + "Name: count, dtype: int64\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# Подсчет количества объектов каждого класса\n", + "class_counts = y.value_counts()\n", + "print(class_counts)\n", + "\n", + "\n", + "class_counts_dict = class_counts.to_dict()\n", + "\n", + "keys = list(class_counts_dict.keys())\n", + "vals = list(class_counts_dict.values())\n", + "\n", + "keys[keys.index(True)] = \"Был приступ\"\n", + "keys[keys.index(False)] = \"Не было приступа\"\n", + "\n", + "# Визуализация\n", + "plt.bar(keys, vals)\n", + "plt.title(f\"Распределение классов\\n\\\"{target_column}\\\"\")\n", + "plt.xlabel(\"Класс\")\n", + "plt.ylabel(\"Количество\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 325, + "metadata": {}, + "outputs": [], + "source": [ + "from imblearn.over_sampling import RandomOverSampler\n", + "from imblearn.under_sampling import RandomUnderSampler\n", + "\n", + "def oversample(X: DataFrame, Y: Series, sampling_strategy=0.5) -> tuple[DataFrame, Series]:\n", + " sampler = RandomOverSampler(sampling_strategy=sampling_strategy)\n", + " x_over, y_over = sampler.fit_resample(X, Y)\n", + " return x_over, y_over \n", + "\n", + "def undersample(X: DataFrame, Y: Series, sampling_strategy=1) -> tuple[DataFrame, Series]:\n", + " sampler = RandomUnderSampler(sampling_strategy=sampling_strategy)\n", + " x_over, y_over = sampler.fit_resample(X, Y)\n", + " return x_over, y_over " + ] + }, + { + "cell_type": "code", + "execution_count": 327, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Данные до аугментации в обучающей выборке\n", + "HadHeartAttack\n", + "False 186069\n", + "True 10748\n", + "Name: count, dtype: int64\n", + "\n", + "Данные после аугментации в обучающей выборке\n", + "HadHeartAttack\n", + "False 10748\n", + "True 10748\n", + "Name: count, dtype: int64\n" + ] + } + ], + "source": [ + "print(\"Данные до аугментации в 
обучающей выборке\")\n", + "print(Y_train.value_counts())\n", + "\n", + "X_train_samplied, Y_train_samplied = X_train, Y_train\n", + "\n", + "# X_train_samplied, Y_train_samplied = oversample(X_train_samplied, Y_train_samplied)\n", + "X_train_samplied, Y_train_samplied = undersample(X_train_samplied, Y_train_samplied)\n", + "print()\n", + "print(\"Данные после аугментации в обучающей выборке\")\n", + "print(Y_train_samplied.value_counts())" + ] + }, + { + "cell_type": "code", + "execution_count": 349, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def show_distribution(df: Series, column_name=\"\") -> None:\n", + " plt.pie(\n", + " df.value_counts(),\n", + " labels=class_counts.index,\n", + " autopct='%1.1f%%',\n", + " colors=['lightblue', 'pink'],\n", + " startangle=45,\n", + " explode=(0, 0.05)\n", + " )\n", + " plt.title(\"Распределение классов\" + (f\"\\n\\\"{column_name}\\\"\" if column_name else \"\"))\n", + " plt.show()\n", + "\n", + "show_distribution(Y_train_samplied, column_name=target_column)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Обучение модели" + ] + }, + { + "cell_type": "code", + "execution_count": 356, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix, classification_report\n", + "import seaborn as sns" + ] + }, + { + "cell_type": "code", + "execution_count": 352, + "metadata": {}, + "outputs": [], + "source": [ + "model = RandomForestClassifier()\n", + "\n", + "start_time = time.time()\n", + "\n", + "model.fit(X_train, Y_train)\n", + "\n", + "train_time = time.time() - start_time" + ] + }, + { + "cell_type": "code", + "execution_count": 353, + "metadata": {}, + "outputs": [], + "source": [ + "Y_pred = model.predict(X_test)\n", + "Y_pred_proba = model.predict_proba(X_test)[:, 1]" + ] + }, + { + "cell_type": "code", + "execution_count": 360, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Время обучения модели: 51.06 секунд\n", + "ROC-AUC: 0.87\n", + "F1-Score: 0.23\n", + "Матрица ошибок:\n", + "[[23151 108]\n", + " [ 1155 188]]\n", + "Отчет по классификации:\n", + " precision recall f1-score support\n", + "\n", + " False 0.95 1.00 0.97 23259\n", 
+ " True 0.64 0.14 0.23 1343\n", + "\n", + " accuracy 0.95 24602\n", + " macro avg 0.79 0.57 0.60 24602\n", + "weighted avg 0.94 0.95 0.93 24602\n", + "\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Метрики\n", + "roc_auc = roc_auc_score(Y_test, Y_pred_proba)\n", + "f1 = f1_score(Y_test, Y_pred)\n", + "\n", + "conf_matrix = confusion_matrix(Y_test, Y_pred)\n", + "class_report = classification_report(Y_test, Y_pred)\n", + "\n", + "# Вывод результатов\n", + "print(f'Время обучения модели: {train_time:.2f} секунд')\n", + "print(f'ROC-AUC: {roc_auc:.2f}')\n", + "print(f'F1-Score: {f1:.2f}')\n", + "print('Матрица ошибок:')\n", + "print(conf_matrix)\n", + "print('Отчет по классификации:')\n", + "print(class_report)\n", + "\n", + "# Визуализация матрицы ошибок\n", + "plt.figure(figsize=(7, 7))\n", + "sns.heatmap(\n", + " conf_matrix,\n", + " annot=True,\n", + " fmt='d',\n", + " cmap='Blues',\n", + " xticklabels=['Нет приступа', 'Был приступ'],\n", + " yticklabels=['Нет приступа', 'Был приступ']\n", + ")\n", + "plt.title('Матрица ошибок')\n", + "plt.xlabel('Предсказанный класс')\n", + "plt.ylabel('Истинный класс')\n", + "plt.show()" + ] } ], "metadata": {