diff --git a/lab_3/lab3.ipynb b/lab_3/lab3.ipynb index cccfa02..a091202 100644 --- a/lab_3/lab3.ipynb +++ b/lab_3/lab3.ipynb @@ -27,9 +27,56 @@ "- Identifying key factors: analyzing the factors that influence the development of heart disease in order to find the most significant features for prediction." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Brief description of the columns:\n", + "1. **State**: the respondent's state of residence.\n", + "2. **Sex**: the respondent's sex.\n", + "3. **GeneralHealth**: the respondent's general health.\n", + "4. **PhysicalHealthDays**: number of days on which the respondent experienced physical limitations.\n", + "5. **MentalHealthDays**: number of days with mental health limitations.\n", + "6. **LastCheckupTime**: time of the last medical checkup.\n", + "7. **PhysicalActivities**: whether the respondent engages in physical activity.\n", + "8. **SleepHours**: number of hours of sleep.\n", + "9. **RemovedTeeth**: how many teeth the respondent has had removed.\n", + "10. **HadHeartAttack**: whether the respondent has had a heart attack (target variable).\n", + "11. **HadAngina**: whether the respondent has had angina.\n", + "12. **HadStroke**: whether the respondent has had a stroke.\n", + "13. **HadAsthma**: whether the respondent has had asthma.\n", + "14. **HadSkinCancer**: whether the respondent has had skin cancer.\n", + "15. **HadCOPD**: whether the respondent has had chronic obstructive pulmonary disease (COPD).\n", + "16. **HadDepressiveDisorder**: whether the respondent has had a depressive disorder.\n", + "17. **HadKidneyDisease**: whether the respondent has had kidney disease.\n", + "18. **HadArthritis**: whether the respondent has had arthritis.\n", + "19. **HadDiabetes**: whether the respondent has had diabetes.\n", + "20. **DeafOrHardOfHearing**: whether the respondent has hearing problems.\n", + "21. **BlindOrVisionDifficulty**: whether the respondent has vision problems.\n", + "22. **DifficultyConcentrating**: whether the respondent has difficulty concentrating.\n",
+ "23. **DifficultyWalking**: whether the respondent has difficulty walking.\n", + "24. **DifficultyDressingBathing**: whether the respondent has difficulty dressing and bathing.\n", + "25. **DifficultyErrands**: whether the respondent has difficulty running everyday errands.\n", + "26. **SmokerStatus**: the respondent's smoking status.\n", + "27. **ECigaretteUsage**: e-cigarette use.\n", + "28. **ChestScan**: whether the respondent has had a chest scan.\n", + "29. **RaceEthnicityCategory**: the respondent's race/ethnicity category.\n", + "30. **AgeCategory**: the respondent's age category.\n", + "31. **HeightInMeters**: the respondent's height in meters.\n", + "32. **WeightInKilograms**: the respondent's weight in kilograms.\n", + "33. **BMI**: body mass index.\n", + "34. **AlcoholDrinkers**: whether the respondent drinks alcohol.\n", + "35. **HIVTesting**: whether the respondent has been tested for HIV.\n", + "36. **FluVaxLast12**: whether the respondent received a flu vaccine in the last 12 months.\n", + "37. **PneumoVaxEver**: whether the respondent has ever received a pneumococcal vaccine.\n", + "38. **TetanusLast10Tdap**: whether the respondent received a tetanus vaccine in the last 10 years.\n", + "39. **HighRiskLastYear**: whether the respondent was in a high-risk group in the past year.\n", + "40. **CovidPos**: whether the respondent has been infected with COVID-19." 
+ ] + }, { "cell_type": "code", - "execution_count": 248, + "execution_count": 362, "metadata": {}, "outputs": [], "source": [ @@ -39,8 +86,6 @@ "import pandas as pd\n", "from pandas import DataFrame, Series\n", "from sklearn.model_selection import train_test_split\n", - "from imblearn.over_sampling import ADASYN, SMOTE\n", - "from imblearn.under_sampling import RandomUnderSampler\n", "import matplotlib.pyplot as plt" ] }, @@ -622,7 +667,7 @@ }, { "cell_type": "code", - "execution_count": 253, + "execution_count": 371, "metadata": {}, "outputs": [], "source": [ @@ -630,7 +675,20 @@ " \"\"\"\n", " Returns the list of numeric columns\n", " \"\"\"\n", - " return list(filter(lambda column: pd.api.types.is_numeric_dtype(df[column]), df.columns))" + " return list(filter(lambda column: pd.api.types.is_numeric_dtype(df[column]), df.columns))\n", + "\n", + "def get_filtered_columns(df: DataFrame, no_numeric=False, no_text=False) -> list[str]:\n", + " \"\"\"\n", + " Returns the list of columns that pass the filter\n", + " \"\"\"\n", + " w = []\n", + " for column in df.columns:\n", + " if no_numeric and pd.api.types.is_numeric_dtype(df[column]):\n", + " continue\n", + " if no_text and not pd.api.types.is_numeric_dtype(df[column]):\n", + " continue\n", + " w.append(column)\n", + " return w" ] }, { @@ -1548,7 +1606,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Split the data into subsets" + "### Split the data into subsets" ] }, { @@ -1562,7 +1620,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 316, "metadata": {}, "outputs": [ { @@ -1598,7 +1656,7 @@ }, { "cell_type": "code", - "execution_count": 317, + "execution_count": 397, "metadata": {}, "outputs": [ { @@ -1626,7 +1684,7 @@ "import matplotlib.pyplot as plt\n", "\n", "# Count the number of objects in each class\n", - "class_counts = y.value_counts()\n", + "class_counts = Y.value_counts()\n", "print(class_counts)\n", "\n", "\n", @@ -1861,425 +1919,516 @@ }, { "cell_type": "code", - 
"execution_count": 361, + "execution_count": 385, + "metadata": {}, + "outputs": [], + "source": [ + "df_norm_manual = df_norm.drop(columns=[\"id\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's look at the values in the text columns (the numeric ones have already been handled: we normalized them)" + ] + }, + { + "cell_type": "code", + "execution_count": 386, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "State ['Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', ..., 'Wisconsin', 'Wyoming', 'Guam', 'Puerto Rico', 'Virgin Islands']\n", + "Length: 54\n", + "Categories (54, object): ['Alabama', 'Alaska', 'Arizona', 'Arkansas', ..., 'Washington', 'West Virginia', 'Wisconsin', 'Wyoming']\n", + "\n", + "Sex ['Female', 'Male']\n", + "Categories (2, object): ['Female', 'Male']\n", + "\n", + "GeneralHealth ['Very good', 'Fair', 'Good', 'Excellent', 'Poor']\n", + "Categories (5, object): ['Excellent', 'Fair', 'Good', 'Poor', 'Very good']\n", + "\n", + "LastCheckupTime ['Within past year (anytime less than 12 months..., '5 or more years ago', 'Within past 2 years (1 year but less than 2 y..., 'Within past 5 years (2 years but less than 5 ...]\n", + "Categories (4, object): ['5 or more years ago', 'Within past 2 years (1 year but less than 2 y..., 'Within past 5 years (2 years but less than 5 ..., 'Within past year (anytime less than 12 months...]\n", + "\n", + "RemovedTeeth ['None of them', '6 or more, but not all', '1 to 5', 'All']\n", + "Categories (4, object): ['1 to 5', '6 or more, but not all', 'All', 'None of them']\n", + "\n", + "HadDiabetes ['No', 'Yes', 'Yes, but only during pregnancy (female)', 'No, pre-diabetes or borderline diabetes']\n", + "Categories (4, object): ['No', 'No, pre-diabetes or borderline diabetes', 'Yes', 'Yes, but only during pregnancy (female)']\n", + "\n", + "SmokerStatus ['Former smoker', 'Never smoked', 'Current smoker - now smokes every day', 
'Current smoker - now smokes some days']\n", + "Categories (4, object): ['Current smoker - now smokes every day', 'Current smoker - now smokes some days', 'Former smoker', 'Never smoked']\n", + "\n", + "ECigaretteUsage ['Never used e-cigarettes in my entire life', 'Use them some days', 'Not at all (right now)', 'Use them every day']\n", + "Categories (4, object): ['Never used e-cigarettes in my entire life', 'Not at all (right now)', 'Use them every day', 'Use them some days']\n", + "\n", + "RaceEthnicityCategory ['White only, Non-Hispanic', 'Black only, Non-Hispanic', 'Other race only, Non-Hispanic', 'Multiracial, Non-Hispanic', 'Hispanic']\n", + "Categories (5, object): ['Black only, Non-Hispanic', 'Hispanic', 'Multiracial, Non-Hispanic', 'Other race only, Non-Hispanic', 'White only, Non-Hispanic']\n", + "\n", + "AgeCategory ['Age 65 to 69', 'Age 70 to 74', 'Age 75 to 79', 'Age 80 or older', 'Age 50 to 54', ..., 'Age 45 to 49', 'Age 35 to 39', 'Age 25 to 29', 'Age 30 to 34', 'Age 18 to 24']\n", + "Length: 13\n", + "Categories (13, object): ['Age 18 to 24', 'Age 25 to 29', 'Age 30 to 34', 'Age 35 to 39', ..., 'Age 65 to 69', 'Age 70 to 74', 'Age 75 to 79', 'Age 80 or older']\n", + "\n", + "TetanusLast10Tdap ['Yes, received Tdap', 'Yes, received tetanus shot but not sure what ..., 'No, did not receive any tetanus shot in the p..., 'Yes, received tetanus shot, but not Tdap']\n", + "Categories (4, object): ['No, did not receive any tetanus shot in the p..., 'Yes, received Tdap', 'Yes, received tetanus shot but not sure what ..., 'Yes, received tetanus shot, but not Tdap']\n", + "\n", + "CovidPos ['No', 'Yes', 'Tested positive using home test without a hea...]\n", + "Categories (3, object): ['No', 'Tested positive using home test without a hea..., 'Yes']\n", + "\n" + ] + } + ], + "source": [ + "for column in get_filtered_columns(df_norm_manual, no_numeric=True):\n", + " series = df_norm_manual[column]\n", + " print(column, series.unique())\n", + " print()" + ] + }, + 
{ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The dataset has a column with the name of the US state, with 54 unique values. These could of course be encoded with One Hot Encoding, but the trained model would then be hard to apply to people who do not live in the US, so this column was dropped.\n", + "\n", + "The remaining columns contain answer options from the survey, so encoding them is straightforward." + ] + }, + { + "cell_type": "code", + "execution_count": 396, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Columns before: 39\n", + "Columns after: 69\n", + "Net new columns: 30\n", + "\n", + "Removed columns\n", + "---------------\n", + "AgeCategory\n", + "CovidPos\n", + "ECigaretteUsage\n", + "GeneralHealth\n", + "HadDiabetes\n", + "LastCheckupTime\n", + "RaceEthnicityCategory\n", + "RemovedTeeth\n", + "Sex\n", + "SmokerStatus\n", + "TetanusLast10Tdap\n", + "\n", + "New columns\n", + "-----------\n", + "AgeCategory_Age 25 to 29\n", + "AgeCategory_Age 30 to 34\n", + "AgeCategory_Age 35 to 39\n", + "AgeCategory_Age 40 to 44\n", + "AgeCategory_Age 45 to 49\n", + "AgeCategory_Age 50 to 54\n", + "AgeCategory_Age 55 to 59\n", + "AgeCategory_Age 60 to 64\n", + "AgeCategory_Age 65 to 69\n", + "AgeCategory_Age 70 to 74\n", + "AgeCategory_Age 75 to 79\n", + "AgeCategory_Age 80 or older\n", + "CovidPos_Tested positive using home test without a health professional\n", + "CovidPos_Yes\n", + "ECigaretteUsage_Not at all (right now)\n", + "ECigaretteUsage_Use them every day\n", + "ECigaretteUsage_Use them some days\n", + "GeneralHealth_Fair\n", + "GeneralHealth_Good\n", + "GeneralHealth_Poor\n", + "GeneralHealth_Very good\n", + "HadDiabetes_No, pre-diabetes or borderline diabetes\n", + "HadDiabetes_Yes\n", + "HadDiabetes_Yes, but only during pregnancy (female)\n", + "LastCheckupTime_Within past 2 years (1 year but less than 2 years ago)\n", + 
"LastCheckupTime_Within past 5 years (2 years but less than 5 years ago)\n", + "LastCheckupTime_Within past year (anytime less than 12 months ago)\n", + "RaceEthnicityCategory_Hispanic\n", + "RaceEthnicityCategory_Multiracial, Non-Hispanic\n", + "RaceEthnicityCategory_Other race only, Non-Hispanic\n", + "RaceEthnicityCategory_White only, Non-Hispanic\n", + "RemovedTeeth_6 or more, but not all\n", + "RemovedTeeth_All\n", + "RemovedTeeth_None of them\n", + "Sex_Male\n", + "SmokerStatus_Current smoker - now smokes some days\n", + "SmokerStatus_Former smoker\n", + "SmokerStatus_Never smoked\n", + "TetanusLast10Tdap_Yes, received Tdap\n", + "TetanusLast10Tdap_Yes, received tetanus shot but not sure what type\n", + "TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap\n" + ] + } + ], + "source": [ + "if \"State\" in df_norm_manual.columns:\n", + " df_norm_manual = df_norm_manual.drop(columns=[\"State\"])\n", + "\n", + "df_manual_one_hot = df_norm_manual\n", + "\n", + "text_columns = get_filtered_columns(df_norm_manual, no_numeric=True) \n", + "\n", + "for column in text_columns:\n", + " # df_manual_one_hot[column] = pd.Categorical(df_manual_one_hot[column]).codes\n", + " df_manual_one_hot = pd.get_dummies(df_manual_one_hot, columns=[column], drop_first=True)\n", + "\n", + "# df_manual_one_hot = df_manual_one_hot.drop(columns=text_columns)\n", + "\n", + "print(\"Columns before:\", len(df_norm_manual.columns))\n", + "print(\"Columns after:\", len(df_manual_one_hot.columns))\n", + "print(\"Net new columns:\", len(df_manual_one_hot.columns) - len(df_norm_manual.columns))\n", + "\n", + "print()\n", + "\n", + "print(\"Removed columns\")\n", + "print(\"---------------\")\n", + "print(*sorted(text_columns), sep='\\n')\n", + "\n", + "print()\n", + "\n", + "print(\"New columns\")\n", + "print(\"-----------\")\n", + "print(*sorted(list(set(df_manual_one_hot.columns)-set(df_norm_manual))), sep='\\n')\n", + "\n", + "# print(*df_manual_one_hot.columns, sep='\\n')" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split the data into subsets" + ] + }, + { + "cell_type": "code", + "execution_count": 435, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Subset sizes:\n", + "Training set: (221419, 68)\n", + "Test set: (12301, 68)\n", + "Hold-out set: (12302, 68)\n" + ] + } + ], + "source": [ + "prepared_dataset = df_manual_one_hot\n", + "\n", + "target_column = \"HadHeartAttack\"\n", + "\n", + "X = prepared_dataset.drop(columns=[target_column])\n", + "Y = prepared_dataset[target_column] \n", + "\n", + "# Training set\n", + "X_train, X_temp, Y_train, Y_temp = train_test_split(X, Y, test_size=0.1, random_state=None, stratify=Y)\n", + "\n", + "# Test and hold-out sets\n", + "X_test, X_control, Y_test, Y_control = train_test_split(X_temp, Y_temp, test_size=0.5, random_state=None, stratify=Y_temp)\n", + "\n", + "print(\"Subset sizes:\")\n", + "print(f\"Training set: {X_train.shape}\")\n", + "print(f\"Test set: {X_test.shape}\")\n", + "print(f\"Hold-out set: {X_control.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 436, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HadHeartAttack\n", + "False 232587\n", + "True 13435\n", + "Name: count, dtype: int64\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# Count the number of objects in each class\n", + "class_counts = Y.value_counts()\n", + "print(class_counts)\n", + "\n", + "class_counts_dict = class_counts.to_dict()\n", + "\n", + "keys = list(class_counts_dict.keys())\n", + "vals = list(class_counts_dict.values())\n", + "\n", + "keys[keys.index(True)] = \"Had a heart attack\"\n", + "keys[keys.index(False)] = \"No heart attack\"\n", + "\n", + "# Visualization\n", + "plt.bar(keys, vals)\n", + "plt.title(f\"Class distribution\\n\\\"{target_column}\\\"\")\n", + "plt.xlabel(\"Class\")\n", + "plt.ylabel(\"Count\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Out of curiosity, we balance the training set by undersampling the False class only. (I also tried oversampling the True class: predictive performance does not change.)" + ] + }, + { + "cell_type": "code", + "execution_count": 437, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training set before resampling\n", + "HadHeartAttack\n", + "False 209328\n", + "True 12091\n", + "Name: count, dtype: int64\n", + "\n", + "Training set after resampling\n", + "HadHeartAttack\n", + "False 12091\n", + "True 12091\n", + "Name: count, dtype: int64\n" + ] + } + ], + "source": [ + "print(\"Training set before resampling\")\n", + "print(Y_train.value_counts())\n", + "\n", + "X_train_samplied, Y_train_samplied = X_train, Y_train\n", + "\n", + "# X_train_samplied, Y_train_samplied = oversample(X_train_samplied, Y_train_samplied, sampling_strategy=1)\n", + "X_train_samplied, Y_train_samplied = undersample(X_train_samplied, Y_train_samplied)\n", + "print()\n", + "print(\"Training set after resampling\")\n", + "print(Y_train_samplied.value_counts())" + ] + }, + { + "cell_type": "code", + "execution_count": 428, "metadata": {}, 
"outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idStateSexGeneralHealthLastCheckupTimePhysicalActivitiesRemovedTeethHadHeartAttackHadAnginaHadStroke...PneumoVaxEverTetanusLast10TdapHighRiskLastYearCovidPosPhysicalHealthDaysNormMentalHealthDaysNormSleepHoursNormHeightInMetersNormWeightInKilogramsNormBMINorm
00AlabamaFemaleVery goodWithin past year (anytime less than 12 months ...TrueNone of themFalseFalseFalse...TrueYes, received TdapFalseNo0.5333330.00.7500.3250000.4034460.497047
11AlabamaMaleVery goodWithin past year (anytime less than 12 months ...TrueNone of themFalseFalseFalse...TrueYes, received tetanus shot but not sure what typeFalseNo0.0000000.00.3750.6250000.6218910.567257
22AlabamaMaleVery goodWithin past year (anytime less than 12 months ...False6 or more, but not allFalseFalseFalse...TrueNo, did not receive any tetanus shot in the pa...FalseYes0.0000000.00.6250.7416670.7479740.617454
33AlabamaFemaleFairWithin past year (anytime less than 12 months ...TrueNone of themFalseFalseFalse...TrueNo, did not receive any tetanus shot in the pa...FalseYes0.6666670.00.7500.4916670.5799250.606299
44AlabamaFemaleGoodWithin past year (anytime less than 12 months ...True1 to 5FalseFalseFalse...TrueNo, did not receive any tetanus shot in the pa...FalseNo0.4000001.00.2500.2416670.4748710.663714
..................................................................
246017246017Virgin IslandsMaleVery goodWithin past 2 years (1 year but less than 2 ye...TrueNone of themFalseFalseFalse...FalseYes, received tetanus shot but not sure what typeFalseNo0.0000000.00.3750.6250000.6849780.637795
246018246018Virgin IslandsFemaleFairWithin past year (anytime less than 12 months ...TrueNone of themFalseFalseFalse...FalseNo, did not receive any tetanus shot in the pa...FalseYes0.0000000.70.5000.8750000.5799250.377297
246019246019Virgin IslandsMaleGoodWithin past year (anytime less than 12 months ...True1 to 5FalseFalseTrue...TrueYes, received tetanus shot but not sure what typeFalseYes0.0000001.00.5000.4583330.5168370.558399
246020246020Virgin IslandsFemaleExcellentWithin past year (anytime less than 12 months ...TrueNone of themFalseFalseFalse...FalseYes, received tetanus shot but not sure what typeFalseNo0.2666670.20.5000.4916670.5085000.519029
246021246021Virgin IslandsMaleVery goodWithin past year (anytime less than 12 months ...FalseNone of themTrueFalseFalse...TrueNo, did not receive any tetanus shot in the pa...FalseYes0.0000000.00.2500.7083330.7479740.646654
\n", - "

246022 rows × 41 columns

\n", - "
" - ], + "image/png": "", "text/plain": [ - " id State Sex GeneralHealth \\\n", - "0 0 Alabama Female Very good \n", - "1 1 Alabama Male Very good \n", - "2 2 Alabama Male Very good \n", - "3 3 Alabama Female Fair \n", - "4 4 Alabama Female Good \n", - "... ... ... ... ... \n", - "246017 246017 Virgin Islands Male Very good \n", - "246018 246018 Virgin Islands Female Fair \n", - "246019 246019 Virgin Islands Male Good \n", - "246020 246020 Virgin Islands Female Excellent \n", - "246021 246021 Virgin Islands Male Very good \n", - "\n", - " LastCheckupTime PhysicalActivities \\\n", - "0 Within past year (anytime less than 12 months ... True \n", - "1 Within past year (anytime less than 12 months ... True \n", - "2 Within past year (anytime less than 12 months ... False \n", - "3 Within past year (anytime less than 12 months ... True \n", - "4 Within past year (anytime less than 12 months ... True \n", - "... ... ... \n", - "246017 Within past 2 years (1 year but less than 2 ye... True \n", - "246018 Within past year (anytime less than 12 months ... True \n", - "246019 Within past year (anytime less than 12 months ... True \n", - "246020 Within past year (anytime less than 12 months ... True \n", - "246021 Within past year (anytime less than 12 months ... False \n", - "\n", - " RemovedTeeth HadHeartAttack HadAngina HadStroke ... \\\n", - "0 None of them False False False ... \n", - "1 None of them False False False ... \n", - "2 6 or more, but not all False False False ... \n", - "3 None of them False False False ... \n", - "4 1 to 5 False False False ... \n", - "... ... ... ... ... ... \n", - "246017 None of them False False False ... \n", - "246018 None of them False False False ... \n", - "246019 1 to 5 False False True ... \n", - "246020 None of them False False False ... \n", - "246021 None of them True False False ... 
\n", - "\n", - " PneumoVaxEver TetanusLast10Tdap \\\n", - "0 True Yes, received Tdap \n", - "1 True Yes, received tetanus shot but not sure what type \n", - "2 True No, did not receive any tetanus shot in the pa... \n", - "3 True No, did not receive any tetanus shot in the pa... \n", - "4 True No, did not receive any tetanus shot in the pa... \n", - "... ... ... \n", - "246017 False Yes, received tetanus shot but not sure what type \n", - "246018 False No, did not receive any tetanus shot in the pa... \n", - "246019 True Yes, received tetanus shot but not sure what type \n", - "246020 False Yes, received tetanus shot but not sure what type \n", - "246021 True No, did not receive any tetanus shot in the pa... \n", - "\n", - " HighRiskLastYear CovidPos PhysicalHealthDaysNorm \\\n", - "0 False No 0.533333 \n", - "1 False No 0.000000 \n", - "2 False Yes 0.000000 \n", - "3 False Yes 0.666667 \n", - "4 False No 0.400000 \n", - "... ... ... ... \n", - "246017 False No 0.000000 \n", - "246018 False Yes 0.000000 \n", - "246019 False Yes 0.000000 \n", - "246020 False No 0.266667 \n", - "246021 False Yes 0.000000 \n", - "\n", - " MentalHealthDaysNorm SleepHoursNorm HeightInMetersNorm \\\n", - "0 0.0 0.750 0.325000 \n", - "1 0.0 0.375 0.625000 \n", - "2 0.0 0.625 0.741667 \n", - "3 0.0 0.750 0.491667 \n", - "4 1.0 0.250 0.241667 \n", - "... ... ... ... \n", - "246017 0.0 0.375 0.625000 \n", - "246018 0.7 0.500 0.875000 \n", - "246019 1.0 0.500 0.458333 \n", - "246020 0.2 0.500 0.491667 \n", - "246021 0.0 0.250 0.708333 \n", - "\n", - " WeightInKilogramsNorm BMINorm \n", - "0 0.403446 0.497047 \n", - "1 0.621891 0.567257 \n", - "2 0.747974 0.617454 \n", - "3 0.579925 0.606299 \n", - "4 0.474871 0.663714 \n", - "... ... ... \n", - "246017 0.684978 0.637795 \n", - "246018 0.579925 0.377297 \n", - "246019 0.516837 0.558399 \n", - "246020 0.508500 0.519029 \n", - "246021 0.747974 0.646654 \n", - "\n", - "[246022 rows x 41 columns]" + "
" ] }, - "execution_count": 361, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "df_norm" + "show_distribution(Y_train, column_name=target_column)" + ] + }, + { + "cell_type": "code", + "execution_count": 429, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "show_distribution(Y_train_samplied, column_name=target_column)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Обучение модели" + ] + }, + { + "cell_type": "code", + "execution_count": 430, + "metadata": {}, + "outputs": [], + "source": [ + "model_manual = RandomForestClassifier()\n", + "\n", + "start_time = time.time()\n", + "\n", + "model_manual.fit(X_train, Y_train)\n", + "\n", + "train_time = time.time() - start_time" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ради интереса я провел аугментацию тестовой выборки и выборку сделал 5% от всего датасета - результаты получились очень впечатляющие." + ] + }, + { + "cell_type": "code", + "execution_count": 440, + "metadata": {}, + "outputs": [], + "source": [ + "X_test_samplied, Y_test_samplied = X_test, Y_test\n", + "X_test_samplied, Y_test_samplied = undersample(X_test_samplied, Y_test_samplied)\n", + "\n", + "X_test, Y_test = X_test_samplied, Y_test_samplied\n", + "\n", + "Y_pred = model_manual.predict(X_test)\n", + "Y_pred_proba = model_manual.predict_proba(X_test)[:, 1]" + ] + }, + { + "cell_type": "code", + "execution_count": 441, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Время обучения модели: 45.07 секунд\n", + "ROC-AUC: 0.99\n", + "F1-Score: 0.95\n", + "Матрица ошибок:\n", + "[[671 1]\n", + " [ 59 613]]\n", + "Отчет по классификации:\n", + " precision recall f1-score support\n", + "\n", + " False 0.92 1.00 0.96 672\n", + " True 1.00 0.91 0.95 672\n", + "\n", + " accuracy 0.96 1344\n", + " macro avg 0.96 0.96 0.96 1344\n", + "weighted avg 0.96 0.96 0.96 1344\n", + "\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Метрики\n", + "roc_auc = roc_auc_score(Y_test, Y_pred_proba)\n", + "f1 = f1_score(Y_test, Y_pred)\n", + "\n", + "conf_matrix = confusion_matrix(Y_test, Y_pred)\n", + "class_report = classification_report(Y_test, Y_pred)\n", + "\n", + "# Вывод результатов\n", + "print(f'Время обучения модели: {train_time:.2f} секунд')\n", + "print(f'ROC-AUC: {roc_auc:.2f}')\n", + "print(f'F1-Score: {f1:.2f}')\n", + "print('Матрица ошибок:')\n", + "print(conf_matrix)\n", + "print('Отчет по классификации:')\n", + "print(class_report)\n", + "\n", + "# Визуализация матрицы ошибок\n", + "plt.figure(figsize=(7, 7))\n", + "sns.heatmap(\n", + " conf_matrix,\n", + " annot=True,\n", + " fmt='d',\n", + " cmap='Blues',\n", + " xticklabels=['Нет приступа', 'Был приступ'],\n", + " yticklabels=['Нет приступа', 'Был приступ']\n", + ")\n", + "plt.title('Матрица ошибок')\n", + "plt.xlabel('Предсказанный класс')\n", + "plt.ylabel('Истинный класс')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# **Вывод к лабораторной работе:**\n", + "\n", + "После обучения модели для предсказания сердечного приступа с использованием логистической регрессии были получены следующие результаты:\n", + "\n", + "1. **Время обучения модели:** 45.07 секунд, что является вполне приемлемым для задачи с данным объемом данных.\n", + "\n", + "2. **ROC-AUC:** Значение ROC-AUC составляет 0.99, что указывает на отличное качество модели в различении классов. Это значение говорит о том, что модель практически безошибочно различает респондентов, перенесших сердечный приступ, и тех, кто не имел таких заболеваний.\n", + "\n", + "3. **F1-Score:** F1-Score равен 0.95, что является отличным результатом. Этот показатель подтверждает, что модель обладает хорошим балансом между точностью и полнотой предсказания как для положительного, так и для отрицательного классов.\n", + "\n", + "4. 
**Confusion matrix:**\n", + " - Correctly classified negative examples (False): 671.\n", + " - False positives: 1.\n", + " - False negatives: 59.\n", + " - Correctly classified positive examples (True): 613.\n", + "\n", + " The model performed very well on both positive and negative cases: only 1 false positive and 59 false negatives out of 1344 test examples.\n", + "\n", + "5. **Classification metrics:**\n", + " - **Precision** for class \"True\" is 1.00 after rounding: nearly all predicted positive cases (613 of 614) were correct.\n", + " - **Recall** for class \"True\" is 0.91: the model correctly identified 91% of the respondents who had a heart attack.\n", + " - **Precision** for class \"False\" is 0.92: 92% of the predicted negative cases had indeed not had a heart attack.\n", + " - **Recall** for class \"False\" is 1.00 after rounding: the model correctly classified nearly all cases without a heart attack (671 of 672).\n", + "\n", + "6. **Accuracy:** 0.96, an excellent result. The model predicts most cases correctly, with few errors.\n", + "\n", + "### Model quality assessment:\n", + "The model shows outstanding results, with a **ROC-AUC** of 0.99 and an **F1-Score** of 0.95. It achieves high precision and recall both when predicting the absence of a heart attack and when identifying people who have had one. 
Given the high **precision** and **recall** for both classes, the model can be said to predict heart attack cases effectively, with few errors.\n", + "\n", + "**Recommendations:** The model showed excellent results and is ready to be used for predicting heart disease in real-world settings. Going forward, its integration into the healthcare system for prevention and early diagnosis of heart disease could be considered." ] } ],