AIM-PIbd-32-Shabunov-O-A/lab_4/lab4.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Laboratory Work No. 4. Supervised Learning.\n",
"\n",
"## Dataset \"Heart attack analysis and prediction dataset\".\n",
"\n",
"[**Link**](https://www.kaggle.com/datasets/kamilpytlak/personal-key-indicators-of-heart-disease)\n",
"\n",
"### Dataset description\n",
"\n",
"**Problem domain**: The dataset comes from medical statistics. It is intended for analyzing the factors associated with the risk of a heart attack, which is important for predicting cardiovascular disease and for designing prevention strategies.\n",
"\n",
"**Relevance**: Cardiovascular diseases are one of the leading causes of death worldwide. Analyzing data on lifestyle, health status and hereditary factors helps identify the key predictors of cardiovascular disease. This dataset supports the analysis of such factors and can be used to build predictive models aimed at reducing risk and enabling timely diagnosis.\n",
"\n",
"**Objects of observation**: Each record describes one person: their health status, lifestyle, demographic characteristics and the presence of certain diseases. The objects of observation are individual patients.\n",
"\n",
"**Object attributes:**\n",
"- `HeartDisease`: whether the person has had a heart attack (Yes/No) (target variable).\n",
"- `BMI`: body mass index, numeric.\n",
"- `Smoking`: smoking (Yes/No).\n",
"- `AlcoholDrinking`: alcohol consumption (Yes/No).\n",
"- `Stroke`: history of stroke (Yes/No).\n",
"- `PhysicalHealth`: number of days per month (0 to 30) when physical health was poor.\n",
"- `MentalHealth`: number of days per month (0 to 30) when mental health was poor.\n",
"- `DiffWalking`: difficulty walking (Yes/No).\n",
"- `Sex`: sex (Male/Female).\n",
"- `AgeCategory`: age group (e.g., 55-59, 80 or older).\n",
"- `Race`: race (e.g., White, Black, Hispanic).\n",
"- `Diabetic`: diabetes status (Yes / No / No, borderline diabetes / Yes (during pregnancy)).\n",
"- `PhysicalActivity`: physical activity (Yes/No).\n",
"- `GenHealth`: general health (Excellent, Very good, Good, Fair, Poor).\n",
"- `SleepTime`: average hours of sleep per day.\n",
"- `Asthma`: asthma (Yes/No).\n",
"- `KidneyDisease`: kidney disease (Yes/No).\n",
"- `SkinCancer`: skin cancer (Yes/No)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>No</td>\n",
" <td>16.60</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>3.0</td>\n",
" <td>30.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>55-59</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>5.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>No</td>\n",
" <td>20.34</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>80 or older</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>No</td>\n",
" <td>26.58</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>20.0</td>\n",
" <td>30.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Fair</td>\n",
" <td>8.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>No</td>\n",
" <td>24.21</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>75-79</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>No</td>\n",
" <td>23.71</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>28.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>40-44</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"0 No 16.60 Yes No No 3.0 \n",
"1 No 20.34 No No Yes 0.0 \n",
"2 No 26.58 Yes No No 20.0 \n",
"3 No 24.21 No No No 0.0 \n",
"4 No 23.71 No No No 28.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n",
"0 30.0 No Female 55-59 White Yes \n",
"1 0.0 No Female 80 or older White No \n",
"2 30.0 No Male 65-69 White Yes \n",
"3 0.0 No Female 75-79 White No \n",
"4 0.0 Yes Female 40-44 White No \n",
"\n",
" PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n",
"0 Yes Very good 5.0 Yes No Yes \n",
"1 Yes Very good 7.0 No No No \n",
"2 Yes Fair 8.0 Yes No No \n",
"3 No Good 6.0 No No Yes \n",
"4 Yes Very good 8.0 No No No "
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"df = pd.read_csv(\".//static//csv//heart_2020_cleaned.csv\")\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Business goal No. 1. Classification task\n",
"\n",
"### Business goal description\n",
"\n",
"**Goal**: predict the presence of a heart attack. The aim is to develop a model that predicts whether a person will have a heart attack (the `HeartDisease` attribute). This task is important for preventing cardiovascular disease because it helps identify risk groups and prescribe timely treatment or preventive measures.\n",
"\n",
"### Achievable model quality\n",
"\n",
"**Main classification metrics:**\n",
"\n",
"- **Accuracy**: the share of correctly classified examples among all observations. It is easy to interpret but can be misleading for imbalanced classes, since a model that always predicts the majority class already scores high.\n",
"- **F1-Score**: the harmonic mean of precision and recall. It suits tasks where both false positives and false negatives matter, especially with imbalanced classes.\n",
"- **ROC AUC** (*Area Under the ROC Curve*): the model's ability to separate the positive and negative classes across all probability thresholds. It ranges from 0.5 (random guessing) to 1.0 (perfect model); values below 0.5 mean the ranking is worse than random. Useful for evaluating models on imbalanced data.\n",
"- **Cohen's Kappa**: the agreement between the model's predictions and the true labels, corrected for agreement by chance. It ranges from -1 (complete disagreement) to 1 (perfect agreement), with 0 meaning chance-level agreement. Convenient for imbalanced data.\n",
"- **MCC** (*Matthews Correlation Coefficient*): the correlation between predicted and true classes that takes all four outcomes (TP, TN, FP, FN) into account. It ranges from -1 (total disagreement) to 1 (perfect agreement), with 0 meaning no better than random. Well suited to imbalanced classes.\n",
"- **Confusion Matrix**: a table showing, for each true class, how many examples the model assigned to each predicted class.\n",
"\n",
"### Choosing a baseline\n",
"\n",
"As the baseline for assessing prediction quality we use the most frequent category of the target `HeartDisease` in the training set. This approach, known as the \"most frequent class baseline\", always predicts the most common value of the target.\n",
"\n",
"### Splitting the dataset\n",
"\n",
"We split the original dataset into a **training** (80%) and a **test** (20%) set, stratified by `HeartDisease`:"
]
},
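{
"cell_type": "markdown",
"metadata": {},
"source": [
"The metrics above can be illustrated on a small hand-made example. This is a sketch with hypothetical labels and scores (10 observations, 4 of them positive). The hard predictions use an assumed threshold of 0.5, and all functions come from `sklearn.metrics`.\n",
"\n",
"```python\n",
"# Hypothetical toy data, not taken from the dataset\n",
"from sklearn.metrics import (\n",
"    accuracy_score, f1_score, roc_auc_score,\n",
"    cohen_kappa_score, matthews_corrcoef, confusion_matrix,\n",
")\n",
"\n",
"y_true_toy = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]\n",
"# Predicted probabilities of the positive class\n",
"y_score_toy = [0.1, 0.2, 0.15, 0.3, 0.4, 0.6, 0.9, 0.7, 0.35, 0.45]\n",
"# Hard labels obtained with an assumed threshold of 0.5\n",
"y_pred_toy = [int(s >= 0.5) for s in y_score_toy]\n",
"\n",
"acc_toy = accuracy_score(y_true_toy, y_pred_toy)\n",
"f1_toy = f1_score(y_true_toy, y_pred_toy)  # F1 of the positive class (1)\n",
"auc_toy = roc_auc_score(y_true_toy, y_score_toy)  # uses scores, not hard labels\n",
"kappa_toy = cohen_kappa_score(y_true_toy, y_pred_toy)\n",
"mcc_toy = matthews_corrcoef(y_true_toy, y_pred_toy)\n",
"cm_toy = confusion_matrix(y_true_toy, y_pred_toy)  # rows: true class, columns: predicted class\n",
"\n",
"print(f'Accuracy: {acc_toy:.3f}, F1: {f1_toy:.3f}, ROC AUC: {auc_toy:.3f}')\n",
"print(f'Kappa: {kappa_toy:.3f}, MCC: {mcc_toy:.3f}')\n",
"print(cm_toy)\n",
"```\n",
"\n",
"With these numbers accuracy is 0.7, but F1 for the positive class is only about 0.57. Kappa and MCC, which account for chance agreement and for all four outcomes, are about 0.35 and 0.36 respectively."
]
},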
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>249455</th>\n",
" <td>No</td>\n",
" <td>65.00</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>30-34</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14270</th>\n",
" <td>No</td>\n",
" <td>31.89</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>55-59</td>\n",
" <td>Hispanic</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>163088</th>\n",
" <td>No</td>\n",
" <td>24.41</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>40-44</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>136626</th>\n",
" <td>Yes</td>\n",
" <td>36.86</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Male</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>No, borderline diabetes</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>265773</th>\n",
" <td>No</td>\n",
" <td>35.15</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>2.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>70-74</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>193686</th>\n",
" <td>No</td>\n",
" <td>30.43</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>55-59</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>207316</th>\n",
" <td>No</td>\n",
" <td>33.66</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>55-59</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>229094</th>\n",
" <td>No</td>\n",
" <td>38.95</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>60-64</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>148788</th>\n",
" <td>No</td>\n",
" <td>35.44</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>70-74</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Very good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35742</th>\n",
" <td>No</td>\n",
" <td>27.26</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>35-39</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 18 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"249455 No 65.00 No No No 3.0 \n",
"14270 No 31.89 No No No 0.0 \n",
"163088 No 24.41 No No No 0.0 \n",
"136626 Yes 36.86 Yes No No 30.0 \n",
"265773 No 35.15 Yes No No 2.0 \n",
"... ... ... ... ... ... ... \n",
"193686 No 30.43 Yes No No 0.0 \n",
"207316 No 33.66 Yes No No 0.0 \n",
"229094 No 38.95 No No No 0.0 \n",
"148788 No 35.44 Yes No No 0.0 \n",
"35742 No 27.26 Yes No No 0.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race \\\n",
"249455 0.0 No Female 30-34 White \n",
"14270 0.0 No Female 55-59 Hispanic \n",
"163088 5.0 No Male 40-44 White \n",
"136626 0.0 Yes Male 65-69 White \n",
"265773 0.0 No Male 70-74 White \n",
"... ... ... ... ... ... \n",
"193686 0.0 No Male 55-59 White \n",
"207316 0.0 No Female 55-59 White \n",
"229094 0.0 No Male 60-64 White \n",
"148788 0.0 No Male 70-74 White \n",
"35742 0.0 No Male 35-39 White \n",
"\n",
" Diabetic PhysicalActivity GenHealth SleepTime Asthma \\\n",
"249455 No Yes Good 7.0 Yes \n",
"14270 No Yes Good 8.0 No \n",
"163088 No Yes Very good 7.0 No \n",
"136626 No, borderline diabetes No Good 8.0 No \n",
"265773 No Yes Good 7.0 No \n",
"... ... ... ... ... ... \n",
"193686 No Yes Excellent 7.0 No \n",
"207316 No Yes Very good 6.0 No \n",
"229094 Yes Yes Good 7.0 No \n",
"148788 No No Very good 6.0 No \n",
"35742 No Yes Very good 8.0 No \n",
"\n",
" KidneyDisease SkinCancer \n",
"249455 No No \n",
"14270 No No \n",
"163088 No No \n",
"136626 No No \n",
"265773 Yes No \n",
"... ... ... \n",
"193686 No No \n",
"207316 No No \n",
"229094 No No \n",
"148788 No No \n",
"35742 No No \n",
"\n",
"[255836 rows x 18 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>249455</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14270</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>163088</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>136626</th>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>265773</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>193686</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>207316</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>229094</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>148788</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35742</th>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease\n",
"249455 No\n",
"14270 No\n",
"163088 No\n",
"136626 Yes\n",
"265773 No\n",
"... ...\n",
"193686 No\n",
"207316 No\n",
"229094 No\n",
"148788 No\n",
"35742 No\n",
"\n",
"[255836 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>215485</th>\n",
" <td>No</td>\n",
" <td>32.89</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>10.0</td>\n",
" <td>5.0</td>\n",
" <td>Yes</td>\n",
" <td>Male</td>\n",
" <td>45-49</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>6.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>150930</th>\n",
" <td>No</td>\n",
" <td>33.00</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>50-54</td>\n",
" <td>Black</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Fair</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>305511</th>\n",
" <td>No</td>\n",
" <td>39.16</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>75-79</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>284576</th>\n",
" <td>No</td>\n",
" <td>28.89</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>170107</th>\n",
" <td>No</td>\n",
" <td>33.96</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>60-64</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>318712</th>\n",
" <td>No</td>\n",
" <td>34.70</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>25-29</td>\n",
" <td>Hispanic</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>8.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>169792</th>\n",
" <td>No</td>\n",
" <td>32.61</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>7.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>60-64</td>\n",
" <td>Hispanic</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Fair</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19564</th>\n",
" <td>No</td>\n",
" <td>25.09</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>2.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>60-64</td>\n",
" <td>Black</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74293</th>\n",
" <td>No</td>\n",
" <td>21.29</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Male</td>\n",
" <td>60-64</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>284877</th>\n",
" <td>Yes</td>\n",
" <td>23.30</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>80 or older</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63959 rows × 18 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"215485 No 32.89 No No No 10.0 \n",
"150930 No 33.00 Yes No No 0.0 \n",
"305511 No 39.16 Yes No No 0.0 \n",
"284576 No 28.89 Yes No No 0.0 \n",
"170107 No 33.96 No No No 0.0 \n",
"... ... ... ... ... ... ... \n",
"318712 No 34.70 Yes No No 30.0 \n",
"169792 No 32.61 Yes No No 7.0 \n",
"19564 No 25.09 Yes No No 0.0 \n",
"74293 No 21.29 Yes Yes No 0.0 \n",
"284877 Yes 23.30 No No No 0.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n",
"215485 5.0 Yes Male 45-49 White Yes \n",
"150930 0.0 No Male 50-54 Black No \n",
"305511 0.0 Yes Female 75-79 White Yes \n",
"284576 0.0 No Male 65-69 White Yes \n",
"170107 0.0 No Female 60-64 White No \n",
"... ... ... ... ... ... ... \n",
"318712 0.0 No Male 25-29 Hispanic No \n",
"169792 0.0 No Female 60-64 Hispanic No \n",
"19564 2.0 No Male 60-64 Black No \n",
"74293 0.0 Yes Male 60-64 White Yes \n",
"284877 0.0 No Female 80 or older White No \n",
"\n",
" PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n",
"215485 No Good 6.0 Yes No No \n",
"150930 No Fair 8.0 No No No \n",
"305511 Yes Very good 7.0 No No No \n",
"284576 Yes Very good 7.0 No No No \n",
"170107 Yes Very good 7.0 No No No \n",
"... ... ... ... ... ... ... \n",
"318712 Yes Excellent 8.0 Yes No No \n",
"169792 Yes Fair 7.0 No No No \n",
"19564 Yes Very good 8.0 No No No \n",
"74293 Yes Good 8.0 No No No \n",
"284877 Yes Good 6.0 No No No \n",
"\n",
"[63959 rows x 18 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>215485</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>150930</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>305511</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>284576</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>170107</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>318712</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>169792</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19564</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74293</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>284877</th>\n",
" <td>Yes</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63959 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease\n",
"215485 No\n",
"150930 No\n",
"305511 No\n",
"284576 No\n",
"170107 No\n",
"... ...\n",
"318712 No\n",
"169792 No\n",
"19564 No\n",
"74293 No\n",
"284877 Yes\n",
"\n",
"[63959 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
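" # Compare with a tolerance: float sums such as 0.7 + 0.2 + 0.1 are not exactly 1.0\n",
" if abs(frac_train + frac_val + frac_test - 1.0) > 1e-9:\n",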
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ]\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" \n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname='HeartDisease', frac_train=0.8, frac_val=0, frac_test=0.2, random_state=9\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
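{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a two-way split, the same stratified result can be obtained with a single `train_test_split` call. The sketch below uses a hypothetical toy frame with 10% positive labels. It also drops the target from the feature matrix. The function above passes the whole frame as `X`, which is why `X_train` and `X_test` still contain the `HeartDisease` column.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hypothetical imbalanced toy frame: 10 'Yes' and 90 'No'\n",
"toy_df = pd.DataFrame({\n",
"    'BMI': [20.0 + i * 0.1 for i in range(100)],\n",
"    'HeartDisease': ['Yes'] * 10 + ['No'] * 90,\n",
"})\n",
"\n",
"# Separate the features from the target\n",
"toy_X = toy_df.drop(columns=['HeartDisease'])\n",
"toy_y = toy_df['HeartDisease']\n",
"\n",
"toy_X_train, toy_X_test, toy_y_train, toy_y_test = train_test_split(\n",
"    toy_X, toy_y, test_size=0.2, stratify=toy_y, random_state=9\n",
")\n",
"\n",
"# stratify keeps the 10% share of 'Yes' in both parts\n",
"print(toy_y_train.value_counts(normalize=True))\n",
"print(toy_y_test.value_counts(normalize=True))\n",
"```"
]
},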
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build the **baseline model** described above and evaluate its *Accuracy* and weighted *F1-Score*:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Most frequent class: No\n",
"Baseline Accuracy: 0.9143982864022264\n",
"Baseline F1: 0.8735112563715003\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score, f1_score\n",
"\n",
"# Find the most frequent class in the training target\n",
"most_frequent_class = y_train.mode().values[0][0]\n",
"print(f\"Most frequent class: {most_frequent_class}\")\n",
"\n",
"# Baseline predictions: every prediction equals the most frequent class\n",
"baseline_predictions: list[str] = [most_frequent_class] * len(y_test)\n",
"\n",
"# Evaluate the baseline (F1 is averaged over classes, weighted by support)\n",
"print('Baseline Accuracy:', accuracy_score(y_test, baseline_predictions))\n",
"print('Baseline F1:', f1_score(y_test, baseline_predictions, average='weighted'))\n",
"\n",
"# Binary-encode the target (Yes -> 1, No -> 0); this is label encoding, not one-hot encoding\n",
"y_train = y_train['HeartDisease'].map({'Yes': 1, 'No': 0})\n",
"y_test = y_test['HeartDisease'].map({'Yes': 1, 'No': 0})"
]
},
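{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same baseline can also be expressed with scikit-learn's `DummyClassifier(strategy='most_frequent')`, which plugs into pipelines and cross-validation like any other estimator. The toy target below is hypothetical (18 `No`, 2 `Yes`). The example also explains why the baseline scores above look high: accuracy equals the share of the majority class, and the weighted F1 is dominated by that class as well.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.dummy import DummyClassifier\n",
"from sklearn.metrics import accuracy_score, f1_score\n",
"\n",
"# Hypothetical toy target; the dummy model ignores the features\n",
"dummy_X = np.zeros((20, 1))\n",
"dummy_y = np.array(['No'] * 18 + ['Yes'] * 2)\n",
"\n",
"dummy = DummyClassifier(strategy='most_frequent')\n",
"dummy.fit(dummy_X, dummy_y)\n",
"dummy_pred = dummy.predict(dummy_X)  # every prediction is 'No'\n",
"\n",
"dummy_acc = accuracy_score(dummy_y, dummy_pred)  # equals the share of 'No'\n",
"# zero_division=0: 'Yes' is never predicted, so its precision is set to 0\n",
"dummy_f1 = f1_score(dummy_y, dummy_pred, average='weighted', zero_division=0)\n",
"print(f'Accuracy: {dummy_acc:.3f}, weighted F1: {dummy_f1:.3f}')\n",
"```"
]
},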
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Choosing the models\n",
"\n",
"The following models were chosen for training:\n",
"\n",
"1. **Random Forest**: an ensemble of many decision trees. It handles nonlinear dependencies and noisy data well, and averaging over many trees makes it much less prone to overfitting than a single decision tree.\n",
"2. **Logistic Regression**: a statistical method for binary classification that models the relationship between the target and the features through the logistic function. It is easy to interpret and fast to train.\n",
"3. **K-Nearest Neighbors** (*KNN*): predicts the class from the k nearest training examples. KNN is intuitive and has no explicit training phase (it only stores the training set). However, prediction can be slow on large datasets, and the results are sensitive to the choice of k and to feature scaling."
]
},
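{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal sketch of how the three chosen models could be trained and compared. It uses synthetic data from `make_classification` (about 9% positives, to mimic the class imbalance) in place of the real features. Logistic regression and KNN are wrapped with `StandardScaler` because KNN relies on distances and regularized logistic regression is also sensitive to feature scale. The hyperparameters are illustrative defaults, not tuned values.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.metrics import f1_score, roc_auc_score\n",
"\n",
"# Synthetic imbalanced data standing in for the real features\n",
"demo_X, demo_y = make_classification(\n",
"    n_samples=2000, n_features=10, weights=[0.91], random_state=9\n",
")\n",
"demo_X_tr, demo_X_te, demo_y_tr, demo_y_te = train_test_split(\n",
"    demo_X, demo_y, test_size=0.2, stratify=demo_y, random_state=9\n",
")\n",
"\n",
"demo_models = {\n",
"    'random_forest': RandomForestClassifier(n_estimators=100, random_state=9),\n",
"    'logistic_regression': make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000)),\n",
"    'knn': make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5)),\n",
"}\n",
"\n",
"demo_scores = {}\n",
"for name, model in demo_models.items():\n",
"    model.fit(demo_X_tr, demo_y_tr)\n",
"    proba = model.predict_proba(demo_X_te)[:, 1]  # probability of class 1\n",
"    demo_scores[name] = {\n",
"        'f1': f1_score(demo_y_te, model.predict(demo_X_te)),\n",
"        'roc_auc': roc_auc_score(demo_y_te, proba),\n",
"    }\n",
"    print(name, {k: round(v, 3) for k, v in demo_scores[name].items()})\n",
"```"
]
},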
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Building the pipeline"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.impute import SimpleImputer\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"# Split the features into numeric and categorical columns\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if df[column].dtype == \"object\"\n",
"]\n",
"\n",
"# Numeric preprocessing: impute missing values with the median, then standardize\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Categorical preprocessing: fill missing values with \"unknown\", then one-hot encode (first category dropped)\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Combined feature preprocessing\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Итоговый конвейер\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")"
]
},
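{
"cell_type": "markdown",
"metadata": {},
"source": [
"The categorical branch one-hot encodes each column, and `drop=\"first\"` removes the first category, so a column with five levels yields four indicator columns. With automatically detected categories, `OneHotEncoder` keeps them in sorted order, so the alphabetically first level is the one dropped. Below is a minimal pure-Python sketch of this behaviour. The function `one_hot_drop_first` is hypothetical, not a scikit-learn function, and the sample values are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def one_hot_drop_first(column_name, values):\n",
"    # Sorted categories without the first one (the dropped reference level)\n",
"    categories = sorted(set(values))[1:]\n",
"    names = [f'{column_name}_{category}' for category in categories]\n",
"    # One row of 0/1 indicators per input value; the dropped level is all zeros\n",
"    rows = [\n",
"        [1.0 if value == category else 0.0 for category in categories]\n",
"        for value in values\n",
"    ]\n",
"    return names, rows\n",
"\n",
"\n",
"gen_health = ['Very good', 'Excellent', 'Good', 'Fair', 'Poor']\n",
"names, rows = one_hot_drop_first('GenHealth', gen_health)\n",
"print(names)\n",
"for value, row in zip(gen_health, rows):\n",
"    print(value, row)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The generated names match the `GenHealth_Fair`, `GenHealth_Good`, `GenHealth_Poor` and `GenHealth_Very good` columns in the preprocessed table. In this toy input, `Excellent` is the dropped level and is encoded as a row of zeros."
]
},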
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Использование конвейера на тренировочных данных"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>BMI</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>HeartDisease_Yes</th>\n",
" <th>Smoking_Yes</th>\n",
" <th>AlcoholDrinking_Yes</th>\n",
" <th>Stroke_Yes</th>\n",
" <th>DiffWalking_Yes</th>\n",
" <th>Sex_Male</th>\n",
" <th>...</th>\n",
" <th>Diabetic_Yes</th>\n",
" <th>Diabetic_Yes (during pregnancy)</th>\n",
" <th>PhysicalActivity_Yes</th>\n",
" <th>GenHealth_Fair</th>\n",
" <th>GenHealth_Good</th>\n",
" <th>GenHealth_Poor</th>\n",
" <th>GenHealth_Very good</th>\n",
" <th>Asthma_Yes</th>\n",
" <th>KidneyDisease_Yes</th>\n",
" <th>SkinCancer_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.773647</td>\n",
" <td>-0.046285</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.065882</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.561206</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>0.630754</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-0.616356</td>\n",
" <td>-0.424023</td>\n",
" <td>0.140196</td>\n",
" <td>-0.065882</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.343623</td>\n",
" <td>3.353355</td>\n",
" <td>-0.489254</td>\n",
" <td>0.630754</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.074421</td>\n",
" <td>-0.172198</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.065882</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>255831</th>\n",
" <td>0.331361</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.065882</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>255832</th>\n",
" <td>0.839853</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.762519</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>255833</th>\n",
" <td>1.672648</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.065882</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>255834</th>\n",
" <td>1.120075</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.762519</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>255835</th>\n",
" <td>-0.167686</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>0.630754</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 38 columns</p>\n",
"</div>"
],
"text/plain": [
" BMI PhysicalHealth MentalHealth SleepTime HeartDisease_Yes \\\n",
"0 5.773647 -0.046285 -0.489254 -0.065882 0.0 \n",
"1 0.561206 -0.424023 -0.489254 0.630754 0.0 \n",
"2 -0.616356 -0.424023 0.140196 -0.065882 0.0 \n",
"3 1.343623 3.353355 -0.489254 0.630754 1.0 \n",
"4 1.074421 -0.172198 -0.489254 -0.065882 0.0 \n",
"... ... ... ... ... ... \n",
"255831 0.331361 -0.424023 -0.489254 -0.065882 0.0 \n",
"255832 0.839853 -0.424023 -0.489254 -0.762519 0.0 \n",
"255833 1.672648 -0.424023 -0.489254 -0.065882 0.0 \n",
"255834 1.120075 -0.424023 -0.489254 -0.762519 0.0 \n",
"255835 -0.167686 -0.424023 -0.489254 0.630754 0.0 \n",
"\n",
" Smoking_Yes AlcoholDrinking_Yes Stroke_Yes DiffWalking_Yes \\\n",
"0 0.0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 0.0 \n",
"3 1.0 0.0 0.0 1.0 \n",
"4 1.0 0.0 0.0 0.0 \n",
"... ... ... ... ... \n",
"255831 1.0 0.0 0.0 0.0 \n",
"255832 1.0 0.0 0.0 0.0 \n",
"255833 0.0 0.0 0.0 0.0 \n",
"255834 1.0 0.0 0.0 0.0 \n",
"255835 1.0 0.0 0.0 0.0 \n",
"\n",
" Sex_Male ... Diabetic_Yes Diabetic_Yes (during pregnancy) \\\n",
"0 0.0 ... 0.0 0.0 \n",
"1 0.0 ... 0.0 0.0 \n",
"2 1.0 ... 0.0 0.0 \n",
"3 1.0 ... 0.0 0.0 \n",
"4 1.0 ... 0.0 0.0 \n",
"... ... ... ... ... \n",
"255831 1.0 ... 0.0 0.0 \n",
"255832 0.0 ... 0.0 0.0 \n",
"255833 1.0 ... 1.0 0.0 \n",
"255834 1.0 ... 0.0 0.0 \n",
"255835 1.0 ... 0.0 0.0 \n",
"\n",
" PhysicalActivity_Yes GenHealth_Fair GenHealth_Good GenHealth_Poor \\\n",
"0 1.0 0.0 1.0 0.0 \n",
"1 1.0 0.0 1.0 0.0 \n",
"2 1.0 0.0 0.0 0.0 \n",
"3 0.0 0.0 1.0 0.0 \n",
"4 1.0 0.0 1.0 0.0 \n",
"... ... ... ... ... \n",
"255831 1.0 0.0 0.0 0.0 \n",
"255832 1.0 0.0 0.0 0.0 \n",
"255833 1.0 0.0 1.0 0.0 \n",
"255834 0.0 0.0 0.0 0.0 \n",
"255835 1.0 0.0 0.0 0.0 \n",
"\n",
" GenHealth_Very good Asthma_Yes KidneyDisease_Yes SkinCancer_Yes \n",
"0 0.0 1.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 \n",
"2 1.0 0.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 0.0 \n",
"4 0.0 0.0 1.0 0.0 \n",
"... ... ... ... ... \n",
"255831 0.0 0.0 0.0 0.0 \n",
"255832 1.0 0.0 0.0 0.0 \n",
"255833 0.0 0.0 0.0 0.0 \n",
"255834 1.0 0.0 0.0 0.0 \n",
"255835 1.0 0.0 0.0 0.0 \n",
"\n",
"[255836 rows x 38 columns]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
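{
"cell_type": "markdown",
"metadata": {},
"source": [
"The table above contains a `HeartDisease_Yes` column. The target was encoded as an input feature because `num_columns` and `cat_columns` were built from all columns of `df`. Below is a minimal sketch that builds the two lists with the target excluded. It assumes the column types are available as a plain name-to-dtype mapping; in the notebook this could come from `df.dtypes.astype(str).to_dict()`. The function `split_feature_columns` is hypothetical and the sample mapping is illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def split_feature_columns(dtypes, target):\n",
"    # dtypes: mapping of column name -> dtype name; the target is removed first\n",
"    features = {name: dtype for name, dtype in dtypes.items() if name != target}\n",
"    numeric = [name for name, dtype in features.items() if dtype != 'object']\n",
"    categorical = [name for name, dtype in features.items() if dtype == 'object']\n",
"    return numeric, categorical\n",
"\n",
"\n",
"sample_dtypes = {\n",
"    'HeartDisease': 'object',\n",
"    'BMI': 'float64',\n",
"    'Smoking': 'object',\n",
"    'SleepTime': 'float64',\n",
"}\n",
"feature_num, feature_cat = split_feature_columns(sample_dtypes, target='HeartDisease')\n",
"print(feature_num)\n",
"print(feature_cat)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Excluding the target from these lists is not enough on its own. If `X_train` still contains `HeartDisease`, `remainder=\"passthrough\"` would forward it unchanged. The target should therefore also be dropped when building `X_train` and `X_test`, for example with `df.drop(columns=['HeartDisease'])`."
]
},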
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Обучение моделей"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Модель: RandomForestClassifier\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[58484 0]\n",
" [ 0 5475]]\n",
"\n",
"Модель: LogisticRegression\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[58484 0]\n",
" [ 0 5475]]\n",
"\n",
"Модель: KNN\n",
"\tPrecision_train: 0.9985596365852307\n",
"\tPrecision_test: 0.9934322549258088\n",
"\tRecall_train: 0.8231345328340488\n",
"\tRecall_test: 0.7459360730593607\n",
"\tAccuracy_train: 0.9847597679763599\n",
"\tAccuracy_test: 0.9778295470535812\n",
"\tF1_train: 0.9024005607149115\n",
"\tF1_test: 0.8520759440851241\n",
"\tROC_AUC_test: 0.872737204165273\n",
"\tCohen_kappa_test: 0.8403545562876147\n",
"\tMCC_test: 0.8504421479558301\n",
"\tConfusion_matrix: [[58457 27]\n",
" [ 1391 4084]]\n",
"\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn import metrics\n",
"\n",
"\n",
"# Оценка качества различных моделей на основе метрик\n",
"def evaluate_models(models, \n",
" pipeline_end: Pipeline, \n",
" X_train: DataFrame, y_train, \n",
" X_test: DataFrame, y_test):\n",
" results = {}\n",
" \n",
" for model_name, model in models.items():\n",
" # Создание конвейера для текущей модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", model),\n",
" ]\n",
" )\n",
" \n",
" # Обучение модели\n",
" model_pipeline.fit(X_train, y_train)\n",
" \n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
" \n",
" # Вычисление метрик для текущей модели\n",
" metrics_dict = {\n",
" \"Precision_train\": metrics.precision_score(y_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
" }\n",
" \n",
" # Сохранение результатов\n",
" results[model_name] = metrics_dict\n",
" \n",
" return results\n",
"\n",
"\n",
"# Выбранные модели для классификации\n",
"models_classification = {\n",
" \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
" \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
" \"KNN\": KNeighborsClassifier(),\n",
"}\n",
"\n",
"results = evaluate_models(models_classification,\n",
" pipeline_end,\n",
" X_train, y_train,\n",
" X_test, y_test)\n",
"\n",
"# Вывод результатов\n",
"for model_name, metrics_dict in results.items():\n",
" print(f\"Модель: {model_name}\")\n",
" for metric_name, value in metrics_dict.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
" print()"
]
},
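{
"cell_type": "markdown",
"metadata": {},
"source": [
"All of the test metrics printed above can be derived from the four confusion-matrix counts. Below is a pure-Python sketch using the standard formulas; the function `binary_metrics` is hypothetical. Fed with the KNN test confusion matrix, it reproduces the printed precision, recall, F1, accuracy, Cohen's kappa and MCC. The counts also show how imbalanced the test set is: 58484 of 63959 rows are negative, so always predicting No would already give an accuracy of about 0.914."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"def binary_metrics(tn, fp, fn, tp):\n",
"    n = tn + fp + fn + tp\n",
"    precision = tp / (tp + fp)\n",
"    recall = tp / (tp + fn)\n",
"    f1 = 2 * precision * recall / (precision + recall)\n",
"    accuracy = (tp + tn) / n\n",
"    # Cohen's kappa: observed agreement corrected for agreement expected by chance\n",
"    expected = ((tp + fp) * (tp + fn) + (tn + fn) * (tn + fp)) / n ** 2\n",
"    kappa = (accuracy - expected) / (1 - expected)\n",
"    # Matthews correlation coefficient\n",
"    mcc = (tp * tn - fp * fn) / math.sqrt(\n",
"        (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)\n",
"    )\n",
"    return {\n",
"        'precision': precision,\n",
"        'recall': recall,\n",
"        'f1': f1,\n",
"        'accuracy': accuracy,\n",
"        'kappa': kappa,\n",
"        'mcc': mcc,\n",
"    }\n",
"\n",
"\n",
"# KNN test confusion matrix [[tn, fp], [fn, tp]] = [[58457, 27], [1391, 4084]]\n",
"knn_metrics = binary_metrics(tn=58457, fp=27, fn=1391, tp=4084)\n",
"for metric_name, value in knn_metrics.items():\n",
"    print(f'{metric_name}: {value:.4f}')"
]
},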
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Результаты:**\n",
"\n",
"1. **Случайный лес (Random Forest)**:\n",
" - Метрики:\n",
" * Precision (обучение): 1.0\n",
" * Precision (тест): 1.0\n",
" * Recall (обучение): 1.0\n",
" * Recall (тест): 1.0\n",
" * Accuracy (обучение): 1.0\n",
" * Accuracy (тест): 1.0\n",
" * F1 Score (обучение): 1.0\n",
" * F1 Score (тест): 1.0\n",
" * ROC AUC (тест): 1.0\n",
" * Cohen Kappa (тест): 1.0\n",
" * MCC (тест): 1.0\n",
" * Confusion Matrix (тест): \n",
" ```\n",
" [[58484 0]\n",
" [ 0 5475]]\n",
" ```\n",
" - ***Вывод***: модель продемонстрировала идеальные результаты как на обучающей, так и на тестовой выборке. Все метрики (Precision, Recall, Accuracy, F1 Score, ROC AUC и др.) равны 1.0, что свидетельствует о 100%-й точности классификации. Вероятно, модель переобучилась, так как такие результаты практически невозможно достичь на реальных данных. Необходимо проверить данные, наличие утечек информации (например, коррелирующих с целевой переменной признаков), а также параметры модели.\n",
"\n",
"2. **Логистическая регрессия (Logistic Regression)**:\n",
" - Метрики:\n",
" * Precision (обучение): 1.0\n",
" * Precision (тест): 1.0\n",
" * Recall (обучение): 1.0\n",
" * Recall (тест): 1.0\n",
" * Accuracy (обучение): 1.0\n",
" * Accuracy (тест): 1.0\n",
" * F1 Score (обучение): 1.0\n",
" * F1 Score (тест): 1.0\n",
" * ROC AUC (тест): 1.0\n",
" * Cohen Kappa (тест): 1.0\n",
" * MCC (тест): 1.0\n",
" * Confusion Matrix (тест): \n",
" ```\n",
" [[58484 0]\n",
" [ 0 5475]]\n",
" ```\n",
" - ***Вывод***: аналогично Random Forest, результаты выглядят идеальными: все метрики равны 1.0, включая ROC AUC, что предполагает идеальную работу модели. Здесь также есть признаки переобучения или утечки данных. Необходимо пересмотреть подготовку данных и проверить конвейер обработки (особенно этапы предобработки).\n",
"\n",
"3. **Метод ближайших соседей (KNN)**:\n",
" - Метрики:\n",
" * Precision (обучение): 0.999\n",
" * Precision (тест): 0.993\n",
" * Recall (обучение): 0.823\n",
" * Recall (тест): 0.746\n",
" * Accuracy (обучение): 0.985\n",
" * Accuracy (тест): 0.978\n",
" * F1 Score (обучение): 0.902\n",
" * F1 Score (тест): 0.852\n",
" * ROC AUC (тест): 0.872\n",
" * Cohen Kappa (тест): 0.840\n",
" * MCC (тест): 0.850\n",
" * Confusion Matrix (тест): \n",
" ```\n",
" [[58457 27]\n",
" [ 1391 4084]]\n",
" ```\n",
" - ***Вывод***: модель KNN выглядит наиболее реалистичной и стабильной, но уступает случайному лесу и логистической регрессии в точности. Эта модель является хорошей точкой отсчета для дальнейших экспериментов."
]
},
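{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, `roc_auc_score` is called with hard class predictions. The reported ROC AUC therefore reduces to balanced accuracy, (TPR + TNR) / 2. For KNN, (4084/5475 + 58457/58484) / 2 gives exactly the 0.8727 printed above. The sketch below illustrates this on a toy example. It uses a hypothetical helper `roc_auc` based on the pairwise formulation, where AUC is the probability that a random positive is scored above a random negative. With scikit-learn, passing `predict_proba(X_test)[:, 1]` instead of the predicted labels would give a threshold-independent ROC AUC."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def roc_auc(y_true, scores):\n",
"    # Probability that a random positive is scored above a random negative (ties count 1/2)\n",
"    positives = [s for y, s in zip(y_true, scores) if y == 1]\n",
"    negatives = [s for y, s in zip(y_true, scores) if y == 0]\n",
"    wins = 0.0\n",
"    for pos in positives:\n",
"        for neg in negatives:\n",
"            if pos > neg:\n",
"                wins += 1.0\n",
"            elif pos == neg:\n",
"                wins += 0.5\n",
"    return wins / (len(positives) * len(negatives))\n",
"\n",
"\n",
"toy_true = [0, 0, 0, 1, 1]\n",
"toy_labels = [0, 1, 0, 1, 0]  # hard predictions\n",
"toy_scores = [0.1, 0.6, 0.2, 0.9, 0.4]  # predicted probabilities\n",
"\n",
"# On hard labels, AUC equals balanced accuracy: here TPR = 1/2 and TNR = 2/3\n",
"print(roc_auc(toy_true, toy_labels), (1 / 2 + 2 / 3) / 2)\n",
"# Scores keep the ranking information and give a different AUC\n",
"print(roc_auc(toy_true, toy_scores))\n",
"\n",
"# Balanced accuracy of KNN from its test confusion matrix\n",
"print((4084 / 5475 + 58457 / 58484) / 2)"
]
},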
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Матрица неточностей"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA/QAAAQTCAYAAADKw2LWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADV0klEQVR4nOzdfXzN9f/H8efZZhdsZ3O1zTIXkYuVXKzSvq5ZlqSEb5FqLruiQqi+lasuREkKUcp8+1ZyURLlIrnIRQopCSGFmIuwITY75/P7w28nx9Bmn7OdzzmP++32uX3zOe+9z/scvnue13m/P++PzTAMQwAAAAAAwFICinsAAAAAAACg4CjoAQAAAACwIAp6AAAAAAAsiIIeAAAAAAALoqAHAAAAAMCCKOgBAAAAALAgCnoAAAAAACwoqLgHAACAmU6fPq3s7GxT+wwODlZoaKipfQIA4C/IZs+hoAcA+IzTp0+rauVwpR90mNpvbGysdu3axQcHAAAKiGz2LAp6AIDPyM7OVvpBh35fX0X2CHOuKss87lTlxN+UnZ3t9x8aAAAoKLLZsyjoAQA+JzzCpvAImyl9OWVOPwAA+DOy2TPYFA8AAAAAAAtihh4A4HMchlMOw7y+AABA4ZDNnkFBDwDwOU4ZcsqcTw1m9QMAgD8jmz2DJfcAAAAAAFgQM/QAAJ/jlFNmLcYzrycAAPwX2ewZFPQAAJ/jMAw5DHOW45nVDwAA/oxs9gyW3AMAAAAAYEEU9PAZ3bp1U5UqVYp7GH7lt99+k81mU1paWrGNoUqVKurWrZvbue3bt6t169aKjIyUzWbTnDlzlJaWJpvNpt9++61YxomilbvxjlkHgKLTvHlzNW/e3LT+LpQTyJ9ly5bJZrNp2bJlxT0U+ACy2TMo6HFZcouj3CMoKEhXXHGFunXrpj/++KO4h1fkunXr5vZ+nHssWLCguIeXx759+zRs2DBt3Ljxom2WLVumDh06KDY2VsHBwYqOjla7du308ccfF91AL1Nqaqo2bdqkF154Qe+9956uu+664h4SAFhSbt6vW7euuIdySatXr9awYcN07NgxU/rL/cI69wgICFCZMmXUpk0brVmzxpTnAAAzcA09CmXEiBGqWrWqTp8+rW+++UZpaWlauXKlfvrpJ4WGhhb38IpUSEiIpkyZkud83bp1i2E0l7Zv3z4NHz5cVapUUb169fI8PnToUI0YMUJXXXWVHnjgAVWuXFl//vmnPv/8c3Xs2FHvv/++7r777qIf+AVs27ZNAQF/fzd56tQprVmzRk8//bT69u3rOn/vvfeqc+fOCgkJKY5hoog5ZcjBrXEAS1q0aFGBf2b16tUaPny4unXrpqioKLfHzs+JgujSpYtuueUWORwO/fLLL5o4caJatGih7777TnXq1LmsPq2kadOmOnXqlIKDg4t7KPABZLNnUNCjUNq0aeOa/ezVq5fKlSunUaNGae7cubrzzjuLeXRFKygoSPfcc49H+v7rr79UsmRJj/R9vlmzZmnEiBHq1KmTPvjgA5UoUcL12KBBg7Rw4UKdOXOmSMaSH+cX6IcOHZKkPB/oAgMDFRgYaNrznjx5UqVKlTKtP5iLe90C1mV28ViYL3IbNGjglu1NmjRRmzZt9Oabb2rixIlmDC/fiiN3AgIC/G6CBp5DNnsGS+5hqiZNmkiSdu7cKUnKzs7WkCFDlJiYqMjISJUqVUpNmjTR0qVL3X4ud2nbK6+8orfeekvVqlVTSEiIrr/+en333Xd5nmfOnDm65pprFBoaqmuuuUaffPLJBcdz8uRJPf7444qPj1dISIhq1qypV155RcZ5O2PabDb17dtXM2fOVEJCgsLCwpSUlKRNmzZJkiZPnqzq1asrNDRUzZs3v+zrsCdOnKirr75aISEhiouLU58+ffIsD2zevLmuueYarV+/Xk2bNlXJkiX1n//8R5KUlZWloUOHqnr16goJCVF8fLwGDx6srKwstz
4WL16sxo0bKyoqSuHh4apZs6arj2XLlun666+XJHXv3t21nDD3Ovhnn31WZcqU0bvvvutWzOdKSUnRrbfeetHX+OOPP6pbt2668sorFRoaqtjYWPXo0UN//vmnW7vjx4+rX79+qlKlikJCQhQdHa2bbrpJGzZscLXZvn27OnbsqNjYWIWGhqpixYrq3LmzMjIyXG3OvTZy2LBhqly5sqSzXz7YbDbXvgoXu4b+iy++UJMmTVSqVClFRESobdu22rx5s1ubbt26KTw8XDt37tQtt9yiiIgIde3a9aLvAQD4k++//15t2rSR3W5XeHi4WrVqpW+++SZPux9//FHNmjVTWFiYKlasqOeff15Tp07N87v5QtfQv/HGG7r66qtVsmRJlS5dWtddd50++OADSWd/9w8aNEiSVLVqVVeu5fZ5oWvojx07pv79+7syqGLFirrvvvt0+PDhS77W8z/nnNtfv379XJ83qlevrlGjRsnpdL+11p9//ql7771XdrtdUVFRSk1N1Q8//JBnP5pL5Y7T6dRrr72mq6++WqGhoYqJidEDDzygo0ePuj3XunXrlJKSonLlyiksLExVq1ZVjx493NpMnz5diYmJioiIkN1uV506dTRu3DjX4xe7hn7mzJlKTExUWFiYypUrp3vuuSfPJZe5r+GPP/5Q+/btFR4ervLly2vgwIFyOByXfJ8B5B8z9DBVbniWLl1akpSZmakpU6aoS5cu6t27t44fP6533nlHKSkp+vbbb/Ms9/7ggw90/PhxPfDAA7LZbBo9erQ6dOigX3/91VVcLlq0SB07dlRCQoJGjhypP//8U927d1fFihXd+jIMQ7fddpuWLl2qnj17ql69elq4cKEGDRqkP/74Q2PHjnVr//XXX2vu3Lnq06ePJGnkyJG69dZbNXjwYE2cOFEPP/ywjh49qtGjR6tHjx766quv8rz+8z8IlChRQpGRkZLOfuAYPny4kpOT9dBDD2nbtm1688039d1332nVqlVuxfOff/6pNm3aqHPnzrrnnnsUExMjp9Op2267TStXrtT999+v2rVra9OmTRo7dqx++eUXzZkzR5K0efNm3Xrrrbr22ms1YsQIhYSEaMeOHVq1apUkqXbt2hoxYoSGDBmi+++/3/Xh5F//+pe2b9+urVu3qkePHoqIiMjX3/n5Fi9erF9//VXdu3dXbGysNm/erLfeekubN2/WN998I5vNJkl68MEHNWvWLPXt21cJCQn6888/tXLlSm3ZskUNGjRQdna2UlJSlJWVpUceeUSxsbH6448/NG/ePB07dsz1vp6rQ4cOioqKUv/+/V3LJMPDwy861vfee0+pqalKSUnRqFGj9Ndff+nNN99U48aN9f3337ttspiTk6OUlBQ1btxYr7zySpGtmMDl4dY4QNHYvHmzmjRpIrvdrsGDB6tEiRKaPHmymjdvruXLl6thw4aSpD/++EMtWrSQzWbTU089pVKlSmnKlCn5mj1/++239eijj6pTp0567LHHdPr0af34449au3at7r77bnXo0EG//PKLPvzwQ40dO1blypWTJJUvX/6C/Z04cUJNmjTRli1b1KNHDzVo0ECHDx/W3LlztXfvXtfPX8j5n3Oks6vomjVrpj/++EMPPPCAKlWqpNWrV+upp57S/v379dprr0k6W4i3a9dO3377rR566CHVqlVLn376qVJTUy/4XBfLnQceeEBpaWnq3r27Hn30Ue3atUvjx4/X999/7/o8cfDgQbVu3Vrly5fXk08+qaioKP32229u++AsXrxYXbp0UatWrTRq1ChJ0pYtW7Rq1So99thjF30Pcp/7+uuv18iRI3XgwAGNGzdOq1at0vfff++2Qs7hcCglJUUNGzbUK6+8oi+//FJjxoxRtWrV9NBDD130OeCbyGYPMYDLMHXqVEOS8eWXXxqHDh0y9uzZY8yaNcsoX768ERISYuzZs8cwDMPIyckxsrKy3H726NGjRkxMjNGjRw/XuV
27dhmSjLJlyxpHjhxxnf/0008NScZnn33mOlevXj2jQoUKxrFjx1znFi1aZEgyKleu7Do3Z84cQ5Lx/PPPuz1/p06
"text/plain": [
"<Figure size 1200x1000 with 7 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"from math import ceil\n",
"\n",
"_, ax = plt.subplots(ceil(len(models_classification) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"\n",
"for index, key in enumerate(models_classification.keys()):\n",
" c_matrix = results[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Yes\", \"No\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
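{
"cell_type": "markdown",
"metadata": {},
"source": [
"When no `labels` argument is given, scikit-learn's `confusion_matrix` orders the classes by sorted label value. For Yes/No (or 0/1) targets, the first row and column therefore correspond to the negative class, and `display_labels` must follow the same order. Below is a pure-Python sketch of this convention; the function `confusion_matrix_sorted` and the toy labels are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def confusion_matrix_sorted(y_true, y_pred):\n",
"    # Classes in sorted order, as scikit-learn does when labels=None\n",
"    labels = sorted(set(y_true) | set(y_pred))\n",
"    index = {label: i for i, label in enumerate(labels)}\n",
"    matrix = [[0] * len(labels) for _ in labels]\n",
"    for true_label, pred_label in zip(y_true, y_pred):\n",
"        # Rows are true classes, columns are predicted classes\n",
"        matrix[index[true_label]][index[pred_label]] += 1\n",
"    return labels, matrix\n",
"\n",
"\n",
"cm_true = ['No', 'No', 'Yes', 'No', 'Yes']\n",
"cm_pred = ['No', 'Yes', 'Yes', 'No', 'No']\n",
"cm_labels, cm_matrix = confusion_matrix_sorted(cm_true, cm_pred)\n",
"print(cm_labels)\n",
"for label, row in zip(cm_labels, cm_matrix):\n",
"    print(label, row)"
]
},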
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Подбор гиперпараметров"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Oleg\\Desktop\\AIM_ForLab4\\lab_4\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие параметры: {'model__criterion': 'gini', 'model__max_depth': 5, 'model__max_features': 'sqrt', 'model__n_estimators': 100}\n"
]
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# Создание конвейера\n",
"pipeline = Pipeline([\n",
" (\"processing\", pipeline_end),\n",
" (\"model\", RandomForestClassifier(random_state=42))\n",
"])\n",
"\n",
"# Установка параметров для поиска по сетке\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке\n",
"grid_search = GridSearchCV(estimator=pipeline, \n",
" param_grid=param_grid,\n",
" n_jobs=-1)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_train, y_train)\n",
"\n",
"# Результаты подбора гиперпараметров\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)"
]
},
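{
"cell_type": "markdown",
"metadata": {},
"source": [
"The grid above contains 3 × 2 × 3 × 2 = 36 combinations. With the default 5-fold cross-validation, `GridSearchCV` fits 180 pipelines, plus one final refit on the whole training set. Below is a pure-Python sketch of how such a grid expands into candidate settings, analogous to scikit-learn's `ParameterGrid`; it does not call scikit-learn."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from itertools import product\n",
"\n",
"# Same grid as in the GridSearchCV cell above\n",
"grid = {\n",
"    'model__n_estimators': [10, 50, 100],\n",
"    'model__max_features': ['sqrt', 'log2'],\n",
"    'model__max_depth': [5, 7, 10],\n",
"    'model__criterion': ['gini', 'entropy'],\n",
"}\n",
"\n",
"# One candidate per combination of values (Cartesian product)\n",
"keys = list(grid)\n",
"candidates = [dict(zip(keys, values)) for values in product(*grid.values())]\n",
"\n",
"n_folds = 5  # GridSearchCV default number of folds\n",
"print(len(candidates), 'candidates,', len(candidates) * n_folds, 'fits')\n",
"print(candidates[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `refit=True` by default, `grid_search.best_estimator_` already holds the pipeline refit on the whole training set with the best parameters. It can be evaluated directly instead of re-creating the model by hand."
]
},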
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Сравнение наборов гиперпараметров"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Стоковая модель:\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[58484 0]\n",
" [ 0 5475]]\n",
"\n",
"Оптимизированная модель:\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 0.9995433372910768\n",
"\tRecall_test: 0.9994520547945206\n",
"\tAccuracy_train: 0.9999609124595444\n",
"\tAccuracy_test: 0.9999530949514532\n",
"\tF1_train: 0.9997716164984242\n",
"\tF1_test: 0.9997259523157029\n",
"\tROC_AUC_test: 0.9997260273972604\n",
"\tCohen_kappa_test: 0.9997003049351247\n",
"\tMCC_test: 0.9997003498302348\n",
"\tConfusion_matrix: [[58484 0]\n",
" [ 3 5472]]\n"
]
}
],
"source": [
"# Обучение модели со старыми гипермараметрами\n",
"pipeline.fit(X_train, y_train)\n",
"\n",
"# Предсказание для обучающей и тестовой выборки\n",
"y_train_predict = pipeline.predict(X_train)\n",
"y_test_predict = pipeline.predict(X_test)\n",
" \n",
"# Вычисление метрик для модели со старыми гипермараметрами\n",
"base_model_metrics = {\n",
" \"Precision_train\": metrics.precision_score(y_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
"}\n",
"\n",
"# Модель с новыми гипермараметрами\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=42,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"# Создание конвейера для модели с новыми гипермараметрами\n",
"optimized_model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", optimized_model),\n",
" ]\n",
")\n",
" \n",
"# Обучение модели с новыми гипермараметрами\n",
"optimized_model_pipeline.fit(X_train, y_train)\n",
" \n",
"# Предсказание для обучающей и тестовой выборки\n",
"y_train_predict = optimized_model_pipeline.predict(X_train)\n",
"y_test_predict = optimized_model_pipeline.predict(X_test)\n",
" \n",
"# Вычисление метрик для модели с новыми гипермараметрами\n",
"optimized_model_metrics = {\n",
" \"Precision_train\": metrics.precision_score(y_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
"}\n",
"\n",
"# Вывод информации\n",
"print('Стоковая модель:')\n",
"for metric_name, value in base_model_metrics.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
"\n",
"print('\\nОптимизированная модель:')\n",
"for metric_name, value in optimized_model_metrics.items():\n",
" print(f\"\\t{metric_name}: {value}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}