Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная...
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная.
— Возьмём?
— Конечно!
Она сзади села и в спальник забралась.
Водила сразу:
— Ну–ка, порули пока, я сейчас... — и к ней в спальник.
Через 10 минут выбрался, сел вперёд, закурил:
— Ай, хороша девка!
Напарник:
— Чё, правда хороша? Ну–ка, порули пока... — и к ней в спальник.
Через 10 минут девка вылезает, садится с водилой, закуривает:
— Ай, хорош у тебя напарник!
Водила:
— Чё правда так хорош? Ну–ка, порули...
2024-12-14 00:24:39 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Лабораторная работа №4. Обучение с учителем.\n",
"\n",
"## Датасет \"Н а б о р данных для анализа и прогнозирования сердечного приступа\".\n",
"\n",
"[**Ссылка**](https://www.kaggle.com/datasets/kamilpytlak/personal-key-indicators-of-heart-disease)\n",
"\n",
"### Описание датасета\n",
"\n",
"**Проблемная область**: Датасет связан с медицинской статистикой и направлен на анализ факторов, связанных с риском сердечного приступа. Это важно для прогнозирования и разработки стратегий профилактики сердечно-сосудистых заболеваний.\n",
"\n",
"**Актуальность**: Сердечно-сосудистые заболевания являются одной из ведущих причин смертности во всем мире. Анализ данных о б образе жизни, состоянии здоровья и наследственных факторах позволяет выделить ключевые предикторы, влияющие на развитие сердечно-сосудистых заболеваний. Этот датасет предоставляет инструменты для анализа таких факторов и может быть полезен в создании прогнозных моделей, направленных на снижение рисков и своевременную диагностику.\n",
"\n",
"**Объекты наблюдения**: Каждая запись представляет собой данные о человеке, включая информацию о б их состоянии здоровья, образе жизни, демографических характеристиках и наличию определенных заболеваний. Объекты наблюдений — это индивидуальные пациенты.\n",
"\n",
"**Атрибуты объектов:**\n",
"- `HeartDisease` — наличие сердечного приступа (Yes/No) (целевая переменная).\n",
"- `BMI` — индекс массы тела (Body Mass Index), числовой показатель.\n",
"- `Smoking` — курение (Yes/No).\n",
"- `AlcoholDrinking` — употребление алкоголя (Yes/No).\n",
"- `Stroke` — наличие инсульта (Yes/No).\n",
"- `PhysicalHealth` — количество дней в месяц, когда физическое здоровье было неудовлетворительным.\n",
"- `MentalHealth` — количество дней в месяц, когда психическое здоровье было неудовлетворительным.\n",
"- `DiffWalking` — трудности при ходьбе (Yes/No).\n",
"- `Sex` — пол (Male/Female).\n",
"- `AgeCategory` — возрастная категория (например, 55-59, 80 or older).\n",
"- `Race` — расовая принадлежность (например, White, Black).\n",
"- `Diabetic` — наличие диабета (Yes/No/No, borderline diabetes).\n",
"- `PhysicalActivity` — физическая активность (Yes/No).\n",
"- `GenHealth` — общее состояние здоровья (от Excellent до Poor).\n",
"- `SleepTime` — среднее количество часов сна за сутки.\n",
"- `Asthma` — наличие астмы (Yes/No).\n",
"- `KidneyDisease` — наличие заболеваний почек (Yes/No).\n",
"- `SkinCancer` — наличие кожного рака (Yes/No)."
]
},
{
"cell_type": "code",
2024-12-14 11:55:47 +04:00
"execution_count": 2,
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная...
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная.
— Возьмём?
— Конечно!
Она сзади села и в спальник забралась.
Водила сразу:
— Ну–ка, порули пока, я сейчас... — и к ней в спальник.
Через 10 минут выбрался, сел вперёд, закурил:
— Ай, хороша девка!
Напарник:
— Чё, правда хороша? Ну–ка, порули пока... — и к ней в спальник.
Через 10 минут девка вылезает, садится с водилой, закуривает:
— Ай, хорош у тебя напарник!
Водила:
— Чё правда так хорош? Ну–ка, порули...
2024-12-14 00:24:39 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>No</td>\n",
" <td>16.60</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>3.0</td>\n",
" <td>30.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>55-59</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>5.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>No</td>\n",
" <td>20.34</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>80 or older</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>No</td>\n",
" <td>26.58</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>20.0</td>\n",
" <td>30.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Fair</td>\n",
" <td>8.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>No</td>\n",
" <td>24.21</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>75-79</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>No</td>\n",
" <td>23.71</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>28.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>40-44</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"0 No 16.60 Yes No No 3.0 \n",
"1 No 20.34 No No Yes 0.0 \n",
"2 No 26.58 Yes No No 20.0 \n",
"3 No 24.21 No No No 0.0 \n",
"4 No 23.71 No No No 28.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n",
"0 30.0 No Female 55-59 White Yes \n",
"1 0.0 No Female 80 or older White No \n",
"2 30.0 No Male 65-69 White Yes \n",
"3 0.0 No Female 75-79 White No \n",
"4 0.0 Yes Female 40-44 White No \n",
"\n",
" PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n",
"0 Yes Very good 5.0 Yes No Yes \n",
"1 Yes Very good 7.0 No No No \n",
"2 Yes Fair 8.0 Yes No No \n",
"3 No Good 6.0 No No Yes \n",
"4 Yes Very good 8.0 No No No "
]
},
2024-12-14 11:55:47 +04:00
"execution_count": 2,
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная...
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная.
— Возьмём?
— Конечно!
Она сзади села и в спальник забралась.
Водила сразу:
— Ну–ка, порули пока, я сейчас... — и к ней в спальник.
Через 10 минут выбрался, сел вперёд, закурил:
— Ай, хороша девка!
Напарник:
— Чё, правда хороша? Ну–ка, порули пока... — и к ней в спальник.
Через 10 минут девка вылезает, садится с водилой, закуривает:
— Ай, хорош у тебя напарник!
Водила:
— Чё правда так хорош? Ну–ка, порули...
2024-12-14 00:24:39 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"df = pd.read_csv(\".//static//csv//heart_2020_cleaned.csv\")\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Бизнес-цель №1. Задача классификации\n",
"\n",
"### Описание бизнес-цели\n",
"\n",
"**Цель**: предсказание наличия сердечного приступа. Цель состоит в разработке модели, которая будет предсказывать, возникнет ли у человека сердечный приступ (признак `HeartDisease`). Это важная задача для профилактики сердечно-сосудистых заболеваний, позволяющая выявить группы риска и назначить своевременное лечение или профилактические меры.\n",
"\n",
"### Достижимый уровень качества модели\n",
"\n",
"**Основные метрики для классификации:**\n",
"\n",
"- **Accuracy** (*точность*) – показывает долю правильно классифицированных примеров среди всех наблюдений. Легко интерпретируется, но может быть недостаточно информативной для несбалансированных классов.\n",
"- **F1-Score** – гармоническое среднее между точностью (precision) и полнотой (recall). Подходит для задач, где важно одновременно учитывать как ложные положительные, так и ложные отрицательные ошибки, особенно при несбалансированных классах.\n",
"- **ROC AUC** (*Area Under the ROC Curve*) – отражает способность модели различать положительные и отрицательные классы на всех уровнях порога вероятности. Значение от 0.5 (случайное угадывание) до 1.0 (идеальная модель). Полезна для оценки модели на несбалансированных данных.\n",
"- **Cohen's Kappa** – измеряет степень согласия между предсказаниями модели и истинными метками с учётом случайного угадывания. Значения варьируются от -1 (полное несогласие) до 1 (идеальное согласие). Удобна для оценки на несбалансированных данных.\n",
"- **MCC** (*Matthews Correlation Coefficient*) – метрика корреляции между предсказаниями и истинными классами, учитывающая все типы ошибок (TP, TN, FP, FN). Значение варьируется от -1 (полная несоответствие) до 1 (идеальное совпадение). Отлично подходит для задач с несбалансированными классами.\n",
"- **Confusion Matrix** (*матрица ошибок*) – матрица ошибок отражает распределение предсказаний модели по каждому из классов.\n",
"\n",
"### Выбор ориентира\n",
"\n",
"В качестве базовой модели для оценки качества предсказаний выбрано использование самой распространённой категории целевой переменной `HeartDisease` в обучающей выборке. Этот подход, известный как \"most frequent class baseline\", заключается в том, что модель всегда предсказывает наиболее часто встречающееся значение наличия сердечного приступа.\n",
"\n",
"### Разбиение набора данных на выборки\n",
"\n",
"Выполним разбиение исходного набора на **обучающую** (80%) и **тестовую** (20%) выборки:"
]
},
{
"cell_type": "code",
2024-12-14 11:55:47 +04:00
"execution_count": 3,
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная...
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная.
— Возьмём?
— Конечно!
Она сзади села и в спальник забралась.
Водила сразу:
— Ну–ка, порули пока, я сейчас... — и к ней в спальник.
Через 10 минут выбрался, сел вперёд, закурил:
— Ай, хороша девка!
Напарник:
— Чё, правда хороша? Ну–ка, порули пока... — и к ней в спальник.
Через 10 минут девка вылезает, садится с водилой, закуривает:
— Ай, хорош у тебя напарник!
Водила:
— Чё правда так хорош? Ну–ка, порули...
2024-12-14 00:24:39 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>249455</th>\n",
" <td>No</td>\n",
" <td>65.00</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>30-34</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14270</th>\n",
" <td>No</td>\n",
" <td>31.89</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>55-59</td>\n",
" <td>Hispanic</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>163088</th>\n",
" <td>No</td>\n",
" <td>24.41</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>40-44</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>136626</th>\n",
" <td>Yes</td>\n",
" <td>36.86</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Male</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>No, borderline diabetes</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>265773</th>\n",
" <td>No</td>\n",
" <td>35.15</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>2.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>70-74</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>193686</th>\n",
" <td>No</td>\n",
" <td>30.43</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>55-59</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>207316</th>\n",
" <td>No</td>\n",
" <td>33.66</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>55-59</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>229094</th>\n",
" <td>No</td>\n",
" <td>38.95</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>60-64</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>148788</th>\n",
" <td>No</td>\n",
" <td>35.44</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>70-74</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Very good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35742</th>\n",
" <td>No</td>\n",
" <td>27.26</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>35-39</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 18 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"249455 No 65.00 No No No 3.0 \n",
"14270 No 31.89 No No No 0.0 \n",
"163088 No 24.41 No No No 0.0 \n",
"136626 Yes 36.86 Yes No No 30.0 \n",
"265773 No 35.15 Yes No No 2.0 \n",
"... ... ... ... ... ... ... \n",
"193686 No 30.43 Yes No No 0.0 \n",
"207316 No 33.66 Yes No No 0.0 \n",
"229094 No 38.95 No No No 0.0 \n",
"148788 No 35.44 Yes No No 0.0 \n",
"35742 No 27.26 Yes No No 0.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race \\\n",
"249455 0.0 No Female 30-34 White \n",
"14270 0.0 No Female 55-59 Hispanic \n",
"163088 5.0 No Male 40-44 White \n",
"136626 0.0 Yes Male 65-69 White \n",
"265773 0.0 No Male 70-74 White \n",
"... ... ... ... ... ... \n",
"193686 0.0 No Male 55-59 White \n",
"207316 0.0 No Female 55-59 White \n",
"229094 0.0 No Male 60-64 White \n",
"148788 0.0 No Male 70-74 White \n",
"35742 0.0 No Male 35-39 White \n",
"\n",
" Diabetic PhysicalActivity GenHealth SleepTime Asthma \\\n",
"249455 No Yes Good 7.0 Yes \n",
"14270 No Yes Good 8.0 No \n",
"163088 No Yes Very good 7.0 No \n",
"136626 No, borderline diabetes No Good 8.0 No \n",
"265773 No Yes Good 7.0 No \n",
"... ... ... ... ... ... \n",
"193686 No Yes Excellent 7.0 No \n",
"207316 No Yes Very good 6.0 No \n",
"229094 Yes Yes Good 7.0 No \n",
"148788 No No Very good 6.0 No \n",
"35742 No Yes Very good 8.0 No \n",
"\n",
" KidneyDisease SkinCancer \n",
"249455 No No \n",
"14270 No No \n",
"163088 No No \n",
"136626 No No \n",
"265773 Yes No \n",
"... ... ... \n",
"193686 No No \n",
"207316 No No \n",
"229094 No No \n",
"148788 No No \n",
"35742 No No \n",
"\n",
"[255836 rows x 18 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>249455</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14270</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>163088</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>136626</th>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>265773</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>193686</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>207316</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>229094</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>148788</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35742</th>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease\n",
"249455 No\n",
"14270 No\n",
"163088 No\n",
"136626 Yes\n",
"265773 No\n",
"... ...\n",
"193686 No\n",
"207316 No\n",
"229094 No\n",
"148788 No\n",
"35742 No\n",
"\n",
"[255836 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>215485</th>\n",
" <td>No</td>\n",
" <td>32.89</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>10.0</td>\n",
" <td>5.0</td>\n",
" <td>Yes</td>\n",
" <td>Male</td>\n",
" <td>45-49</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>6.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>150930</th>\n",
" <td>No</td>\n",
" <td>33.00</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>50-54</td>\n",
" <td>Black</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Fair</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>305511</th>\n",
" <td>No</td>\n",
" <td>39.16</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>75-79</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>284576</th>\n",
" <td>No</td>\n",
" <td>28.89</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>170107</th>\n",
" <td>No</td>\n",
" <td>33.96</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>60-64</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>318712</th>\n",
" <td>No</td>\n",
" <td>34.70</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>25-29</td>\n",
" <td>Hispanic</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>8.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>169792</th>\n",
" <td>No</td>\n",
" <td>32.61</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>7.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>60-64</td>\n",
" <td>Hispanic</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Fair</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19564</th>\n",
" <td>No</td>\n",
" <td>25.09</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>2.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>60-64</td>\n",
" <td>Black</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74293</th>\n",
" <td>No</td>\n",
" <td>21.29</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Male</td>\n",
" <td>60-64</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>284877</th>\n",
" <td>Yes</td>\n",
" <td>23.30</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>80 or older</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63959 rows × 18 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"215485 No 32.89 No No No 10.0 \n",
"150930 No 33.00 Yes No No 0.0 \n",
"305511 No 39.16 Yes No No 0.0 \n",
"284576 No 28.89 Yes No No 0.0 \n",
"170107 No 33.96 No No No 0.0 \n",
"... ... ... ... ... ... ... \n",
"318712 No 34.70 Yes No No 30.0 \n",
"169792 No 32.61 Yes No No 7.0 \n",
"19564 No 25.09 Yes No No 0.0 \n",
"74293 No 21.29 Yes Yes No 0.0 \n",
"284877 Yes 23.30 No No No 0.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n",
"215485 5.0 Yes Male 45-49 White Yes \n",
"150930 0.0 No Male 50-54 Black No \n",
"305511 0.0 Yes Female 75-79 White Yes \n",
"284576 0.0 No Male 65-69 White Yes \n",
"170107 0.0 No Female 60-64 White No \n",
"... ... ... ... ... ... ... \n",
"318712 0.0 No Male 25-29 Hispanic No \n",
"169792 0.0 No Female 60-64 Hispanic No \n",
"19564 2.0 No Male 60-64 Black No \n",
"74293 0.0 Yes Male 60-64 White Yes \n",
"284877 0.0 No Female 80 or older White No \n",
"\n",
" PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n",
"215485 No Good 6.0 Yes No No \n",
"150930 No Fair 8.0 No No No \n",
"305511 Yes Very good 7.0 No No No \n",
"284576 Yes Very good 7.0 No No No \n",
"170107 Yes Very good 7.0 No No No \n",
"... ... ... ... ... ... ... \n",
"318712 Yes Excellent 8.0 Yes No No \n",
"169792 Yes Fair 7.0 No No No \n",
"19564 Yes Very good 8.0 No No No \n",
"74293 Yes Good 8.0 No No No \n",
"284877 Yes Good 6.0 No No No \n",
"\n",
"[63959 rows x 18 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>215485</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>150930</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>305511</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>284576</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>170107</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>318712</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>169792</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19564</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74293</th>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>284877</th>\n",
" <td>Yes</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63959 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease\n",
"215485 No\n",
"150930 No\n",
"305511 No\n",
"284576 No\n",
"170107 No\n",
"... ...\n",
"318712 No\n",
"169792 No\n",
"19564 No\n",
"74293 No\n",
"284877 Yes\n",
"\n",
"[63959 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ]\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" \n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname='HeartDisease', frac_train=0.8, frac_val=0, frac_test=0.2, random_state=9\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Построим **базовую модель**, описанную выше, и оценим е е метрики *Accuracy* и *F1-Score*:"
]
},
{
"cell_type": "code",
2024-12-14 11:55:47 +04:00
"execution_count": 5,
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная...
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная.
— Возьмём?
— Конечно!
Она сзади села и в спальник забралась.
Водила сразу:
— Ну–ка, порули пока, я сейчас... — и к ней в спальник.
Через 10 минут выбрался, сел вперёд, закурил:
— Ай, хороша девка!
Напарник:
— Чё, правда хороша? Ну–ка, порули пока... — и к ней в спальник.
Через 10 минут девка вылезает, садится с водилой, закуривает:
— Ай, хорош у тебя напарник!
Водила:
— Чё правда так хорош? Ну–ка, порули...
2024-12-14 00:24:39 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Самый частый класс: No\n",
"Baseline Accuracy: 0.9143982864022264\n",
"Baseline F1: 0.8735112563715003\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score, f1_score\n",
"\n",
"# Определяем самый частый класс\n",
"most_frequent_class = y_train.mode().values[0][0]\n",
"print(f\"Самый частый класс: {most_frequent_class}\")\n",
"\n",
"# Вычисляем предсказания базовой модели (все предсказания равны самому частому классу)\n",
"baseline_predictions: list[str] = [most_frequent_class] * len(y_test)\n",
"\n",
"# Оцениваем базовую модель\n",
"print('Baseline Accuracy:', accuracy_score(y_test, baseline_predictions))\n",
"print('Baseline F1:', f1_score(y_test, baseline_predictions, average='weighted'))\n",
"\n",
"# Унитарное кодирование для целевого признака\n",
"y_train = y_train['HeartDisease'].map({'Yes': 1, 'No': 0})\n",
"y_test = y_test['HeartDisease'].map({'Yes': 1, 'No': 0})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Выбор моделей обучения\n",
"\n",
"Для обучения были выбраны следующие модели:\n",
"\n",
"1. **Случайный лес** (*Random Forest*): Ансамблевая модель, которая использует множество решающих деревьев. Она хорошо справляется с нелинейными зависимостями и шумом в данных, а также обладает устойчивостью к переобучению.\n",
"2. **Логистическая регрессия** (*Logistic Regression*): Статистический метод для бинарной классификации, который моделирует зависимость между целевой переменной и независимыми признаками, используя логистическую функцию. Она проста в интерпретации и быстра в обучении.\n",
"3. **Метод ближайших соседей** (*KNN*): Алгоритм классификации, который предсказывает класс на основе ближайших k обучающих примеров. KNN интуитивно понятен и не требует обучения, но может быть медленным на больших данных и чувствительным к выбору параметров."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Построение конвейера"
]
},
{
"cell_type": "code",
2024-12-14 11:55:47 +04:00
"execution_count": 4,
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная...
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная.
— Возьмём?
— Конечно!
Она сзади села и в спальник забралась.
Водила сразу:
— Ну–ка, порули пока, я сейчас... — и к ней в спальник.
Через 10 минут выбрался, сел вперёд, закурил:
— Ай, хороша девка!
Напарник:
— Чё, правда хороша? Ну–ка, порули пока... — и к ней в спальник.
Через 10 минут девка вылезает, садится с водилой, закуривает:
— Ай, хорош у тебя напарник!
Водила:
— Чё правда так хорош? Ну–ка, порули...
2024-12-14 00:24:39 +04:00
"metadata": {},
"outputs": [],
"source": [
"from sklearn.impute import SimpleImputer\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"# Разделение признаков на числовые и категориальные\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if df[column].dtype == \"object\"\n",
"]\n",
"\n",
"# Числовая обработка: заполнение пропусков медианой и стандартизация\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Категориальная обработка: заполнение пропусков значением \"unknown\" и кодирование\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Общий конвейер обработки признаков\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Итоговый конвейер\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Использование конвейера на тренировочных данных"
]
},
{
"cell_type": "code",
2024-12-14 11:55:47 +04:00
"execution_count": 5,
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная...
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная.
— Возьмём?
— Конечно!
Она сзади села и в спальник забралась.
Водила сразу:
— Ну–ка, порули пока, я сейчас... — и к ней в спальник.
Через 10 минут выбрался, сел вперёд, закурил:
— Ай, хороша девка!
Напарник:
— Чё, правда хороша? Ну–ка, порули пока... — и к ней в спальник.
Через 10 минут девка вылезает, садится с водилой, закуривает:
— Ай, хорош у тебя напарник!
Водила:
— Чё правда так хорош? Ну–ка, порули...
2024-12-14 00:24:39 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>BMI</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>HeartDisease_Yes</th>\n",
" <th>Smoking_Yes</th>\n",
" <th>AlcoholDrinking_Yes</th>\n",
" <th>Stroke_Yes</th>\n",
" <th>DiffWalking_Yes</th>\n",
" <th>Sex_Male</th>\n",
" <th>...</th>\n",
" <th>Diabetic_Yes</th>\n",
" <th>Diabetic_Yes (during pregnancy)</th>\n",
" <th>PhysicalActivity_Yes</th>\n",
" <th>GenHealth_Fair</th>\n",
" <th>GenHealth_Good</th>\n",
" <th>GenHealth_Poor</th>\n",
" <th>GenHealth_Very good</th>\n",
" <th>Asthma_Yes</th>\n",
" <th>KidneyDisease_Yes</th>\n",
" <th>SkinCancer_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.773647</td>\n",
" <td>-0.046285</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.065882</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.561206</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>0.630754</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-0.616356</td>\n",
" <td>-0.424023</td>\n",
" <td>0.140196</td>\n",
" <td>-0.065882</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.343623</td>\n",
" <td>3.353355</td>\n",
" <td>-0.489254</td>\n",
" <td>0.630754</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.074421</td>\n",
" <td>-0.172198</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.065882</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>255831</th>\n",
" <td>0.331361</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.065882</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>255832</th>\n",
" <td>0.839853</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.762519</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>255833</th>\n",
" <td>1.672648</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.065882</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>255834</th>\n",
" <td>1.120075</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>-0.762519</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>255835</th>\n",
" <td>-0.167686</td>\n",
" <td>-0.424023</td>\n",
" <td>-0.489254</td>\n",
" <td>0.630754</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 38 columns</p>\n",
"</div>"
],
"text/plain": [
" BMI PhysicalHealth MentalHealth SleepTime HeartDisease_Yes \\\n",
"0 5.773647 -0.046285 -0.489254 -0.065882 0.0 \n",
"1 0.561206 -0.424023 -0.489254 0.630754 0.0 \n",
"2 -0.616356 -0.424023 0.140196 -0.065882 0.0 \n",
"3 1.343623 3.353355 -0.489254 0.630754 1.0 \n",
"4 1.074421 -0.172198 -0.489254 -0.065882 0.0 \n",
"... ... ... ... ... ... \n",
"255831 0.331361 -0.424023 -0.489254 -0.065882 0.0 \n",
"255832 0.839853 -0.424023 -0.489254 -0.762519 0.0 \n",
"255833 1.672648 -0.424023 -0.489254 -0.065882 0.0 \n",
"255834 1.120075 -0.424023 -0.489254 -0.762519 0.0 \n",
"255835 -0.167686 -0.424023 -0.489254 0.630754 0.0 \n",
"\n",
" Smoking_Yes AlcoholDrinking_Yes Stroke_Yes DiffWalking_Yes \\\n",
"0 0.0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 0.0 \n",
"3 1.0 0.0 0.0 1.0 \n",
"4 1.0 0.0 0.0 0.0 \n",
"... ... ... ... ... \n",
"255831 1.0 0.0 0.0 0.0 \n",
"255832 1.0 0.0 0.0 0.0 \n",
"255833 0.0 0.0 0.0 0.0 \n",
"255834 1.0 0.0 0.0 0.0 \n",
"255835 1.0 0.0 0.0 0.0 \n",
"\n",
" Sex_Male ... Diabetic_Yes Diabetic_Yes (during pregnancy) \\\n",
"0 0.0 ... 0.0 0.0 \n",
"1 0.0 ... 0.0 0.0 \n",
"2 1.0 ... 0.0 0.0 \n",
"3 1.0 ... 0.0 0.0 \n",
"4 1.0 ... 0.0 0.0 \n",
"... ... ... ... ... \n",
"255831 1.0 ... 0.0 0.0 \n",
"255832 0.0 ... 0.0 0.0 \n",
"255833 1.0 ... 1.0 0.0 \n",
"255834 1.0 ... 0.0 0.0 \n",
"255835 1.0 ... 0.0 0.0 \n",
"\n",
" PhysicalActivity_Yes GenHealth_Fair GenHealth_Good GenHealth_Poor \\\n",
"0 1.0 0.0 1.0 0.0 \n",
"1 1.0 0.0 1.0 0.0 \n",
"2 1.0 0.0 0.0 0.0 \n",
"3 0.0 0.0 1.0 0.0 \n",
"4 1.0 0.0 1.0 0.0 \n",
"... ... ... ... ... \n",
"255831 1.0 0.0 0.0 0.0 \n",
"255832 1.0 0.0 0.0 0.0 \n",
"255833 1.0 0.0 1.0 0.0 \n",
"255834 0.0 0.0 0.0 0.0 \n",
"255835 1.0 0.0 0.0 0.0 \n",
"\n",
" GenHealth_Very good Asthma_Yes KidneyDisease_Yes SkinCancer_Yes \n",
"0 0.0 1.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 \n",
"2 1.0 0.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 0.0 \n",
"4 0.0 0.0 1.0 0.0 \n",
"... ... ... ... ... \n",
"255831 0.0 0.0 0.0 0.0 \n",
"255832 1.0 0.0 0.0 0.0 \n",
"255833 0.0 0.0 0.0 0.0 \n",
"255834 1.0 0.0 0.0 0.0 \n",
"255835 1.0 0.0 0.0 0.0 \n",
"\n",
"[255836 rows x 38 columns]"
]
},
2024-12-14 11:55:47 +04:00
"execution_count": 5,
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная...
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная.
— Возьмём?
— Конечно!
Она сзади села и в спальник забралась.
Водила сразу:
— Ну–ка, порули пока, я сейчас... — и к ней в спальник.
Через 10 минут выбрался, сел вперёд, закурил:
— Ай, хороша девка!
Напарник:
— Чё, правда хороша? Ну–ка, порули пока... — и к ней в спальник.
Через 10 минут девка вылезает, садится с водилой, закуривает:
— Ай, хорош у тебя напарник!
Водила:
— Чё правда так хорош? Ну–ка, порули...
2024-12-14 00:24:39 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Обучение моделей"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Модель: RandomForestClassifier\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[58484 0]\n",
" [ 0 5475]]\n",
"\n",
"Модель: LogisticRegression\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[58484 0]\n",
" [ 0 5475]]\n",
"\n",
"Модель: KNN\n",
"\tPrecision_train: 0.9985596365852307\n",
"\tPrecision_test: 0.9934322549258088\n",
"\tRecall_train: 0.8231345328340488\n",
"\tRecall_test: 0.7459360730593607\n",
"\tAccuracy_train: 0.9847597679763599\n",
"\tAccuracy_test: 0.9778295470535812\n",
"\tF1_train: 0.9024005607149115\n",
"\tF1_test: 0.8520759440851241\n",
"\tROC_AUC_test: 0.872737204165273\n",
"\tCohen_kappa_test: 0.8403545562876147\n",
"\tMCC_test: 0.8504421479558301\n",
"\tConfusion_matrix: [[58457 27]\n",
" [ 1391 4084]]\n",
"\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn import metrics\n",
"\n",
"\n",
"# Оценка качества различных моделей на основе метрик\n",
"def evaluate_models(models, \n",
" pipeline_end: Pipeline, \n",
" X_train: DataFrame, y_train, \n",
" X_test: DataFrame, y_test):\n",
" results = {}\n",
" \n",
" for model_name, model in models.items():\n",
" # Создание конвейера для текущей модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", model),\n",
" ]\n",
" )\n",
" \n",
" # Обучение модели\n",
" model_pipeline.fit(X_train, y_train)\n",
" \n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
" \n",
" # Вычисление метрик для текущей модели\n",
" metrics_dict = {\n",
" \"Precision_train\": metrics.precision_score(y_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
" }\n",
" \n",
" # Сохранение результатов\n",
" results[model_name] = metrics_dict\n",
" \n",
" return results\n",
"\n",
"\n",
"# Выбранные модели для классификации\n",
"models_classification = {\n",
" \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
" \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
" \"KNN\": KNeighborsClassifier(),\n",
"}\n",
"\n",
"results = evaluate_models(models_classification,\n",
" pipeline_end,\n",
" X_train, y_train,\n",
" X_test, y_test)\n",
"\n",
"# Вывод результатов\n",
"for model_name, metrics_dict in results.items():\n",
" print(f\"Модель: {model_name}\")\n",
" for metric_name, value in metrics_dict.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Результаты:**\n",
"\n",
"1. **Случайный лес (Random Forest)**:\n",
" - Метрики:\n",
" * Precision (обучение): 1.0\n",
" * Precision (тест): 1.0\n",
" * Recall (обучение): 1.0\n",
" * Recall (тест): 1.0\n",
" * Accuracy (обучение): 1.0\n",
" * Accuracy (тест): 1.0\n",
" * F1 Score (обучение): 1.0\n",
" * F1 Score (тест): 1.0\n",
" * ROC AUC (тест): 1.0\n",
" * Cohen Kappa (тест): 1.0\n",
" * MCC (тест): 1.0\n",
" * Confusion Matrix (тест): \n",
" ```\n",
" [[58484 0]\n",
" [ 0 5475]]\n",
" ```\n",
" - ***Вывод***: модель продемонстрировала идеальные результаты как на обучающей, так и на тестовой выборке. В с е метрики (Precision, Recall, Accuracy, F1 Score, ROC AUC и др.) равны 1.0, что свидетельствует о 100%-й точности классификации. Вероятно, модель переобучилась, так как такие результаты практически невозможно достичь на реальных данных. Необходимо проверить данные, наличие утечек информации (например, коррелирующих с целевой переменной признаков), а также параметры модели.\n",
"\n",
"2. **Логистическая регрессия (Logistic Regression)**:\n",
" - Метрики:\n",
" * Precision (обучение): 1.0\n",
" * Precision (тест): 1.0\n",
" * Recall (обучение): 1.0\n",
" * Recall (тест): 1.0\n",
" * Accuracy (обучение): 1.0\n",
" * Accuracy (тест): 1.0\n",
" * F1 Score (обучение): 1.0\n",
" * F1 Score (тест): 1.0\n",
" * ROC AUC (тест): 1.0\n",
" * Cohen Kappa (тест): 1.0\n",
" * MCC (тест): 1.0\n",
" * Confusion Matrix (тест): \n",
" ```\n",
" [[58484 0]\n",
" [ 0 5475]]\n",
" ```\n",
" - ***Вывод***: аналогично Random Forest, результаты выглядят идеальными: все метрики равны 1.0, включая ROC AUC, что предполагает идеальную работу модели. Здесь также есть признаки переобучения или утечки данных. Необходимо пересмотреть подготовку данных и проверить конвейер обработки (особенно этапы предобработки).\n",
"\n",
"3. **Метод ближайших соседей (KNN)**:\n",
" - Метрики:\n",
" * Precision (обучение): 0.999\n",
" * Precision (тест): 0.993\n",
" * Recall (обучение): 0.823\n",
" * Recall (тест): 0.746\n",
" * Accuracy (обучение): 0.985\n",
" * Accuracy (тест): 0.978\n",
" * F1 Score (обучение): 0.902\n",
" * F1 Score (тест): 0.852\n",
" * ROC AUC (тест): 0.872\n",
" * Cohen Kappa (тест): 0.840\n",
" * MCC (тест): 0.850\n",
" * Confusion Matrix (тест): \n",
" ```\n",
" [[58457 27]\n",
" [ 1391 4084]]\n",
" ```\n",
" - ***Вывод***: модель KNN выглядит наиболее реалистичной и стабильной, но уступает случайному лесу и логистической регрессии в точности. Эта модель является хорошей точкой отсчета для дальнейших экспериментов."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Матрица неточностей"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA/QAAAQTCAYAAADKw2LWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADV0klEQVR4nOzdfXzN9f/H8efZZhdsZ3O1zTIXkYuVXKzSvq5ZlqSEb5FqLruiQqi+lasuREkKUcp8+1ZyURLlIrnIRQopCSGFmIuwITY75/P7w28nx9Bmn7OdzzmP++32uX3zOe+9z/scvnue13m/P++PzTAMQwAAAAAAwFICinsAAAAAAACg4CjoAQAAAACwIAp6AAAAAAAsiIIeAAAAAAALoqAHAAAAAMCCKOgBAAAAALAgCnoAAAAAACwoqLgHAACAmU6fPq3s7GxT+wwODlZoaKipfQIA4C/IZs+hoAcA+IzTp0+rauVwpR90mNpvbGysdu3axQcHAAAKiGz2LAp6AIDPyM7OVvpBh35fX0X2CHOuKss87lTlxN+UnZ3t9x8aAAAoKLLZsyjoAQA+JzzCpvAImyl9OWVOPwAA+DOy2TPYFA8AAAAAAAtihh4A4HMchlMOw7y+AABA4ZDNnkFBDwDwOU4ZcsqcTw1m9QMAgD8jmz2DJfcAAAAAAFgQM/QAAJ/jlFNmLcYzrycAAPwX2ewZFPQAAJ/jMAw5DHOW45nVDwAA/oxs9gyW3AMAAAAAYEEU9PAZ3bp1U5UqVYp7GH7lt99+k81mU1paWrGNoUqVKurWrZvbue3bt6t169aKjIyUzWbTnDlzlJaWJpvNpt9++61YxomilbvxjlkHgKLTvHlzNW/e3LT+LpQTyJ9ly5bJZrNp2bJlxT0U+ACy2TMo6HFZcouj3CMoKEhXXHGFunXrpj/++KO4h1fkunXr5vZ+nHssWLCguIeXx759+zRs2DBt3Ljxom2WLVumDh06KDY2VsHBwYqOjla7du308ccfF91AL1Nqaqo2bdqkF154Qe+9956uu+664h4SAFhSbt6vW7euuIdySatXr9awYcN07NgxU/rL/cI69wgICFCZMmXUpk0brVmzxpTnAAAzcA09CmXEiBGqWrWqTp8+rW+++UZpaWlauXKlfvrpJ4WGhhb38IpUSEiIpkyZkud83bp1i2E0l7Zv3z4NHz5cVapUUb169fI8PnToUI0YMUJXXXWVHnjgAVWuXFl//vmnPv/8c3Xs2FHvv/++7r777qIf+AVs27ZNAQF/fzd56tQprVmzRk8//bT69u3rOn/vvfeqc+fOCgkJKY5hoog5ZcjBrXEAS1q0aFGBf2b16tUaPny4unXrpqioKLfHzs+JgujSpYtuueUWORwO/fLLL5o4caJatGih7777TnXq1LmsPq2kadOmOnXqlIKDg4t7KPABZLNnUNCjUNq0aeOa/ezVq5fKlSunUaNGae7cubrzzjuLeXRFKygoSPfcc49H+v7rr79UsmRJj/R9vlmzZmnEiBHq1KmTPvjgA5UoUcL12KBBg7Rw4UKdOXOmSMaSH+cX6IcOHZKkPB/oAgMDFRgYaNrznjx5UqVKlTKtP5iLe90C1mV28ViYL3IbNGjglu1NmjRRmzZt9Oabb2rixIlmDC/fiiN3AgIC/G6CBp5DNnsGS+5hqiZNmkiSdu7cKUnKzs7WkCFDlJiYqMjISJUqVUpNmjTR0qVL3X4ud2nbK6+8orfeekvVqlVTSEiIrr/+en333Xd5nmfOnDm65pprFBoaqmuuuUaffPLJBcdz8uRJPf7444qPj1dISIhq1qypV155RcZ5O2PabDb17dtXM2fOVEJCgsLCwpSUlKRNmzZJkiZPnqzq1asrNDRUzZs3v+zrsCdOnKirr75aISEhiouLU58+ffIsD2zevLmuueYarV+/Xk2bNlXJkiX1n//8R5KUlZWloUOHqnr16goJCVF8fLwGDx6srKwstz4WL16sxo0bKyoqSuHh4apZs6arj2XLlun666+XJHXv3t21nDD3Ovhnn31WZcqU0bvvvutWzOdKSUnRrbfeetHX+OOPP6pbt2668sorFRoaqtjYWPXo0UN//vmnW7vjx4+rX79+qlKlikJCQhQdHa2bbrpJGzZscLXZvn27OnbsqNjYWIWGhqpixYrq3LmzMjIyXG3OvTZy2LBhqly5sqSzXz7YbDbXvgoXu4b+iy++UJMmTVSqVClFRESobdu22rx5s1ubbt26KTw8XDt37tQtt9yiiIgIde3a9aLvAQD4k++//15t2rSR3W5XeHi4WrVqpW+++SZPux9//FHNmjVTWFiYKlasqOeff15Tp07N87v5QtfQv/HGG7r66qtVsmRJlS5dWtddd50++OADSWd/9w8aNEiSVLVqVVeu5fZ5oWvojx07pv79+7syqGLFirrvvvt0+PDhS77W8z/nnNtfv379XJ83qlevrlGjRsnpdL+11p9//ql7771XdrtdUVFRSk1N1Q8//JBnP5pL5Y7T6dRrr72mq6++WqGhoYqJidEDDzygo0ePuj3XunXrlJKSonLlyiksLExVq1ZVjx493NpMnz5diYmJioiIkN1uV506dTRu3DjX4xe7hn7mzJlKTExUWFiYypUrp3vuuSfPJZe5r+GPP/5Q+/btFR4ervLly2vgwIFyOByXfJ8B5B8z9DBVbniWLl1akpSZmakpU6aoS5cu6t27t44fP6533nlHKSkp+vbbb/Ms9/7ggw90/PhxPfDAA7LZbBo9erQ6dOigX3/91VVcLlq0SB07dlRCQoJGjhypP//8U927d1fFihXd+jIMQ7fddpuWLl2qnj17ql69elq4cKEGDRqkP/74Q2PHjnVr//XXX2vu3Lnq06ePJGnkyJG69dZbNXjwYE2cOFEPP/ywjh49qtGjR6tHjx766quv8rz+8z8IlChRQpGRkZLOfuAYPny4kpOT9dBDD2nbtm1688039d1332nVqlVuxfOff/6pNm3aqHPnzrrnnnsUExMjp9Op2267TStXrtT999+v2rVra9OmTRo7dqx++eUXzZkzR5K0efNm3Xrrrbr22ms1YsQIhYSEaMeOHVq1apUkqXbt2hoxYoSGDBmi+++/3/Xh5F//+pe2b9+urVu3qkePHoqIiMjX3/n5Fi9erF9//VXdu3dXbGysNm/erLfeekubN2/WN998I5vNJkl68MEHNWvWLPXt21cJCQn6888/tXLlSm3ZskUNGjRQdna2UlJSlJWVpUceeUSxsbH6448/NG/ePB07dsz1vp6rQ4cOioqKUv/+/V3LJMPDwy861vfee0+pqalKSUnRqFGj9Ndff+nNN99U48aN9f3337ttspiTk6OUlBQ1btxYr7zySpGtmMDl4dY4QNHYvHmzmjRpIrvdrsGDB6tEiRKaPHmymjdvruXLl6thw4aSpD/++EMtWrSQzWbTU089pVKlSmnKlCn5mj1/++239eijj6pTp0567LHHdPr0af34449au3at7r77bnXo0EG//PKLPvzwQ40dO1blypWTJJUvX/6C/Z04cUJNmjTRli1b1KNHDzVo0ECHDx/W3LlztXfvXtfPX8j5n3Oks6vomjVrpj/++EMPPPCAKlWqpNWrV+upp57S/v379dprr0k6W4i3a9dO3377rR566CHVqlVLn376qVJTUy/4XBfLnQceeEBpaWnq3r27Hn30Ue3atUvjx4/X999/7/o8cfDgQbVu3Vrly5fXk08+qaioKP32229u++AsXrxYXbp0UatWrTRq1ChJ0pYtW7Rq1So99thjF30Pcp/7+uuv18iRI3XgwAGNGzdOq1at0vfff++2Qs7hcCglJUUNGzbUK6+8oi+//FJjxoxRtWrV9NBDD130OeCbyGYPMYDLMHXqVEOS8eWXXxqHDh0y9uzZY8yaNcsoX768ERISYuzZs8cwDMPIyckxsrKy3H726NGjRkxMjNGjRw/XuV27dhmSjLJlyxpHjhxxnf/0008NScZnn33mOlevXj2jQoUKxrFjx1znFi1aZEgyKleu7Do3Z84cQ5Lx/PPPuz1/p06
"text/plain": [
"<Figure size 1200x1000 with 7 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"from math import ceil\n",
"\n",
"_, ax = plt.subplots(ceil(len(models_classification) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"\n",
"for index, key in enumerate(models_classification.keys()):\n",
" c_matrix = results[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Yes\", \"No\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Подбор гиперпараметров"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Oleg\\Desktop\\AIM_ForLab4\\lab_4\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие параметры: {'model__criterion': 'gini', 'model__max_depth': 5, 'model__max_features': 'sqrt', 'model__n_estimators': 100}\n"
]
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# Создание конвейера\n",
"pipeline = Pipeline([\n",
" (\"processing\", pipeline_end),\n",
" (\"model\", RandomForestClassifier(random_state=42))\n",
"])\n",
"\n",
"# Установка параметров для поиска по сетке\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке\n",
"grid_search = GridSearchCV(estimator=pipeline, \n",
" param_grid=param_grid,\n",
" n_jobs=-1)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_train, y_train)\n",
"\n",
"# Результаты подбора гиперпараметров\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Сравнение наборов гиперпараметров"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Стоковая модель:\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[58484 0]\n",
" [ 0 5475]]\n",
"\n",
"Оптимизированная модель:\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 0.9995433372910768\n",
"\tRecall_test: 0.9994520547945206\n",
"\tAccuracy_train: 0.9999609124595444\n",
"\tAccuracy_test: 0.9999530949514532\n",
"\tF1_train: 0.9997716164984242\n",
"\tF1_test: 0.9997259523157029\n",
"\tROC_AUC_test: 0.9997260273972604\n",
"\tCohen_kappa_test: 0.9997003049351247\n",
"\tMCC_test: 0.9997003498302348\n",
"\tConfusion_matrix: [[58484 0]\n",
" [ 3 5472]]\n"
]
}
],
"source": [
"# Обучение модели с о старыми гипермараметрами\n",
"pipeline.fit(X_train, y_train)\n",
"\n",
"# Предсказание для обучающей и тестовой выборки\n",
"y_train_predict = pipeline.predict(X_train)\n",
"y_test_predict = pipeline.predict(X_test)\n",
" \n",
"# Вычисление метрик для модели с о старыми гипермараметрами\n",
"base_model_metrics = {\n",
" \"Precision_train\": metrics.precision_score(y_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
"}\n",
"\n",
"# Модель с новыми гипермараметрами\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=42,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"# Создание конвейера для модели с новыми гипермараметрами\n",
"optimized_model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", optimized_model),\n",
" ]\n",
")\n",
" \n",
"# Обучение модели с новыми гипермараметрами\n",
"optimized_model_pipeline.fit(X_train, y_train)\n",
" \n",
"# Предсказание для обучающей и тестовой выборки\n",
"y_train_predict = optimized_model_pipeline.predict(X_train)\n",
"y_test_predict = optimized_model_pipeline.predict(X_test)\n",
" \n",
"# Вычисление метрик для модели с новыми гипермараметрами\n",
"optimized_model_metrics = {\n",
" \"Precision_train\": metrics.precision_score(y_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
"}\n",
"\n",
"# Вывод информации\n",
"print('Стоковая модель:')\n",
"for metric_name, value in base_model_metrics.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
"\n",
"print('\\nО птимизир о ва нна я модель:')\n",
"for metric_name, value in optimized_model_metrics.items():\n",
" print(f\"\\t{metric_name}: {value}\")"
]
2024-12-14 11:55:47 +04:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Бизнес-цель №2. Задача регрессии\n",
"\n",
"### Описание бизнес-цели\n",
"\n",
"**Цель**: прогнозирование количества дней с плохим физическим здоровьем. Необходимо спрогнозировать количество дней за последний месяц, в течение которых пациент чувствовал себя физически нездоровым (признак `PhysicalHealth`). Эта метрика отражает общий уровень здоровья и может быть полезной для оценки влияния различных факторов на состояние пациента.\n",
"\n",
"### Достижимый уровень качества модели\n",
"\n",
"**Основные метрики для регрессии:**\n",
"\n",
"- **Средняя абсолютная ошибка** (*Mean Absolute Error, MAE*) – показывает среднее абсолютное отклонение между предсказанными и фактическими значениями. Легко интерпретируется, особенно в финансовых данных, где каждая ошибка в долларах имеет значение.\n",
"- **Среднеквадратичная ошибка** (*Mean Squared Error, MSE*) – показывает, насколько отклоняются прогнозы модели от истинных значений в квадрате. Подходит для оценки общего качества модели.\n",
"- **Коэффициент детерминации** (*R²*) – указывает, какую долю дисперсии зависимой переменной объясняет модель. R² варьируется от 0 до 1 (чем ближе к 1, тем лучше).\n",
"\n",
"### Выбор ориентира\n",
"\n",
"В качестве базовой модели для оценки качества предсказаний выбрано использование среднего значения целевого признака `PhysicalHealth` на обучающей выборке. Это простой и интуитивно понятный метод, который служит минимальным ориентиром для сравнения с более сложными моделями. Базовая модель помогает установить начальный уровень ошибок (MAE, MSE) и показатель качества (R²), которые сложные модели должны улучшить, чтобы оправдать своё использование.\n",
"\n",
"### Разбиение набора данных на выборки\n",
"\n",
"Выполним разбиение исходного набора на **обучающую** (80%) и **тестовую** (20%) выборки:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>46650</th>\n",
" <td>No</td>\n",
" <td>30.90</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>70-74</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Poor</td>\n",
" <td>7.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>305695</th>\n",
" <td>No</td>\n",
" <td>23.75</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>45-49</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>6.0</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17353</th>\n",
" <td>No</td>\n",
" <td>34.70</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>2.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>70-74</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>154614</th>\n",
" <td>No</td>\n",
" <td>26.37</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>80 or older</td>\n",
" <td>Black</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Fair</td>\n",
" <td>4.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>146811</th>\n",
" <td>No</td>\n",
" <td>18.79</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>18-24</td>\n",
" <td>Other</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>224078</th>\n",
" <td>No</td>\n",
" <td>24.13</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>45-49</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Very good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14534</th>\n",
" <td>No</td>\n",
" <td>22.32</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>50-54</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>5.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>156850</th>\n",
" <td>No</td>\n",
" <td>23.78</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>2.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>35-39</td>\n",
" <td>Black</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>221285</th>\n",
" <td>No</td>\n",
" <td>26.52</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>65-69</td>\n",
" <td>Hispanic</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16625</th>\n",
" <td>No</td>\n",
" <td>23.57</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>40-44</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 18 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"46650 No 30.90 No No No 30.0 \n",
"305695 No 23.75 No Yes No 0.0 \n",
"17353 No 34.70 No No No 0.0 \n",
"154614 No 26.37 Yes No No 0.0 \n",
"146811 No 18.79 No No No 0.0 \n",
"... ... ... ... ... ... ... \n",
"224078 No 24.13 Yes No No 0.0 \n",
"14534 No 22.32 Yes No No 0.0 \n",
"156850 No 23.78 Yes No No 2.0 \n",
"221285 No 26.52 No No No 0.0 \n",
"16625 No 23.57 No No No 0.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n",
"46650 0.0 Yes Female 70-74 White Yes \n",
"305695 0.0 No Male 45-49 White No \n",
"17353 2.0 No Female 70-74 White No \n",
"154614 0.0 Yes Female 80 or older Black Yes \n",
"146811 5.0 No Female 18-24 Other No \n",
"... ... ... ... ... ... ... \n",
"224078 0.0 No Female 45-49 White No \n",
"14534 0.0 No Female 50-54 White No \n",
"156850 0.0 No Female 35-39 Black No \n",
"221285 0.0 No Female 65-69 Hispanic Yes \n",
"16625 0.0 No Male 40-44 White No \n",
"\n",
" PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n",
"46650 No Poor 7.0 Yes No Yes \n",
"305695 Yes Excellent 6.0 Yes Yes No \n",
"17353 No Good 7.0 No No Yes \n",
"154614 No Fair 4.0 No No No \n",
"146811 Yes Very good 6.0 No No No \n",
"... ... ... ... ... ... ... \n",
"224078 No Very good 8.0 No No No \n",
"14534 Yes Excellent 5.0 No No No \n",
"156850 Yes Very good 6.0 No No No \n",
"221285 Yes Good 8.0 No No No \n",
"16625 Yes Good 7.0 No No No \n",
"\n",
"[255836 rows x 18 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PhysicalHealth</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>46650</th>\n",
" <td>30.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>305695</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17353</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>154614</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>146811</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>224078</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14534</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>156850</th>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>221285</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16625</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" PhysicalHealth\n",
"46650 30.0\n",
"305695 0.0\n",
"17353 0.0\n",
"154614 0.0\n",
"146811 0.0\n",
"... ...\n",
"224078 0.0\n",
"14534 0.0\n",
"156850 2.0\n",
"221285 0.0\n",
"16625 0.0\n",
"\n",
"[255836 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>146589</th>\n",
" <td>No</td>\n",
" <td>19.45</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Male</td>\n",
" <td>25-29</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>12.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>216017</th>\n",
" <td>No</td>\n",
" <td>26.36</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>40-44</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19624</th>\n",
" <td>No</td>\n",
" <td>24.59</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Male</td>\n",
" <td>55-59</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65923</th>\n",
" <td>No</td>\n",
" <td>23.44</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>20.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>60-64</td>\n",
" <td>Asian</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63362</th>\n",
" <td>No</td>\n",
" <td>31.32</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>2.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>45-49</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Very good</td>\n",
" <td>6.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>252474</th>\n",
" <td>No</td>\n",
" <td>42.37</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>18-24</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>147913</th>\n",
" <td>No</td>\n",
" <td>32.08</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>40-44</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>244674</th>\n",
" <td>No</td>\n",
" <td>31.28</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>50-54</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>215373</th>\n",
" <td>No</td>\n",
" <td>31.65</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>45-49</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>179461</th>\n",
" <td>No</td>\n",
" <td>27.37</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63959 rows × 18 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"146589 No 19.45 No No No 1.0 \n",
"216017 No 26.36 No No No 0.0 \n",
"19624 No 24.59 Yes No No 0.0 \n",
"65923 No 23.44 Yes No No 0.0 \n",
"63362 No 31.32 No No No 0.0 \n",
"... ... ... ... ... ... ... \n",
"252474 No 42.37 No No No 1.0 \n",
"147913 No 32.08 Yes No No 0.0 \n",
"244674 No 31.28 No No No 3.0 \n",
"215373 No 31.65 No No No 0.0 \n",
"179461 No 27.37 No No No 0.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n",
"146589 0.0 Yes Male 25-29 White No \n",
"216017 0.0 No Female 40-44 White No \n",
"19624 0.0 Yes Male 55-59 White No \n",
"65923 20.0 No Female 60-64 Asian No \n",
"63362 2.0 No Female 45-49 White No \n",
"... ... ... ... ... ... ... \n",
"252474 5.0 No Female 18-24 White No \n",
"147913 0.0 No Male 40-44 White No \n",
"244674 0.0 No Female 50-54 White Yes \n",
"215373 0.0 No Male 45-49 White No \n",
"179461 0.0 No Male 65-69 White Yes \n",
"\n",
" PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n",
"146589 Yes Good 12.0 No No No \n",
"216017 Yes Very good 7.0 No No No \n",
"19624 Yes Good 6.0 No No No \n",
"65923 Yes Very good 6.0 No No No \n",
"63362 No Very good 6.0 Yes No No \n",
"... ... ... ... ... ... ... \n",
"252474 Yes Good 8.0 Yes No No \n",
"147913 Yes Excellent 8.0 No No No \n",
"244674 Yes Good 7.0 No No Yes \n",
"215373 Yes Very good 8.0 No No No \n",
"179461 No Good 7.0 No No No \n",
"\n",
"[63959 rows x 18 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PhysicalHealth</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>146589</th>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>216017</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19624</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65923</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63362</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>252474</th>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>147913</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>244674</th>\n",
" <td>3.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>215373</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>179461</th>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63959 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" PhysicalHealth\n",
"146589 1.0\n",
"216017 0.0\n",
"19624 0.0\n",
"65923 0.0\n",
"63362 0.0\n",
"... ...\n",
"252474 1.0\n",
"147913 0.0\n",
"244674 3.0\n",
"215373 0.0\n",
"179461 0.0\n",
"\n",
"[63959 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X_df_train, X_df_val, X_df_test, y_df_train, y_df_val, y_df_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname='PhysicalHealth', frac_train=0.8, frac_val=0, frac_test=0.2, random_state=9\n",
")\n",
"\n",
"display(\"X_train\", X_df_train)\n",
"display(\"y_train\", y_df_train)\n",
"\n",
"display(\"X_test\", X_df_test)\n",
"display(\"y_test\", y_df_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Построим **базовую модель**, описанную выше, и оценим е е метрики *MAE*, *MSE* и *R²*:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline MAE: 5.081172924146543\n",
"Baseline MSE: 63.21384755665578\n",
"Baseline R²: -6.286438036795516e-10\n"
]
}
],
"source": [
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
"\n",
"# Вычисляем предсказания базовой модели (среднее значение целевой переменной)\n",
"baseline_predictions = [y_df_train.mean()] * len(y_df_test) # type: ignore\n",
"\n",
"# Оцениваем базовую модель\n",
"print('Baseline MAE:', mean_absolute_error(y_df_test, baseline_predictions))\n",
"print('Baseline MSE:', mean_squared_error(y_df_test, baseline_predictions))\n",
"print('Baseline R²:', r2_score(y_df_test, baseline_predictions))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Выбор моделей обучения\n",
"\n",
"Для обучения были выбраны следующие модели:\n",
"\n",
"1. **Случайный лес** (*Random Forest*): Ансамблевая модель, которая использует множество решающих деревьев. Она хорошо справляется с нелинейными зависимостями и шумом в данных, а также обладает устойчивостью к переобучению.\n",
"2. **Линейная регрессия** (*Linear Regression*): Простая модель, предполагающая линейную зависимость между признаками и целевой переменной. Она быстро обучается и предоставляет легкую интерпретацию результатов.\n",
"3. **Градиентный бустинг** (*Gradient Boosting*): Мощная модель, создающая ансамбль деревьев, которые корректируют ошибки предыдущих. Эта модель эффективна для сложных наборов данных и обеспечивает высокую точность предсказаний."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Использование конвейера на тренировочных данных\n",
"\n",
"Конвейер уже был построен при решении задачи классификации. Применяем готовый конвейер:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>BMI</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>HeartDisease_Yes</th>\n",
" <th>Smoking_Yes</th>\n",
" <th>AlcoholDrinking_Yes</th>\n",
" <th>Stroke_Yes</th>\n",
" <th>DiffWalking_Yes</th>\n",
" <th>Sex_Male</th>\n",
" <th>...</th>\n",
" <th>Diabetic_Yes</th>\n",
" <th>Diabetic_Yes (during pregnancy)</th>\n",
" <th>PhysicalActivity_Yes</th>\n",
" <th>GenHealth_Fair</th>\n",
" <th>GenHealth_Good</th>\n",
" <th>GenHealth_Poor</th>\n",
" <th>GenHealth_Very good</th>\n",
" <th>Asthma_Yes</th>\n",
" <th>KidneyDisease_Yes</th>\n",
" <th>SkinCancer_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.404994</td>\n",
" <td>3.349099</td>\n",
" <td>-0.490224</td>\n",
" <td>-0.068438</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>-0.718814</td>\n",
" <td>-0.424073</td>\n",
" <td>-0.490224</td>\n",
" <td>-0.765508</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.002262</td>\n",
" <td>-0.424073</td>\n",
" <td>-0.238838</td>\n",
" <td>-0.068438</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>-0.307013</td>\n",
" <td>-0.424073</td>\n",
" <td>-0.490224</td>\n",
" <td>-2.159646</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>-1.498407</td>\n",
" <td>-0.424073</td>\n",
" <td>0.138242</td>\n",
" <td>-0.765508</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1.256887</td>\n",
" <td>1.336741</td>\n",
" <td>-0.490224</td>\n",
" <td>-0.765508</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>-0.506627</td>\n",
" <td>-0.424073</td>\n",
" <td>-0.490224</td>\n",
" <td>-0.068438</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>-1.690161</td>\n",
" <td>-0.424073</td>\n",
" <td>-0.490224</td>\n",
" <td>0.628631</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>-0.030384</td>\n",
" <td>-0.424073</td>\n",
" <td>-0.490224</td>\n",
" <td>0.628631</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>-0.167127</td>\n",
" <td>3.349099</td>\n",
" <td>1.395175</td>\n",
" <td>-0.068438</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>10 rows × 38 columns</p>\n",
"</div>"
],
"text/plain": [
" BMI PhysicalHealth MentalHealth SleepTime HeartDisease_Yes \\\n",
"0 0.404994 3.349099 -0.490224 -0.068438 0.0 \n",
"1 -0.718814 -0.424073 -0.490224 -0.765508 0.0 \n",
"2 1.002262 -0.424073 -0.238838 -0.068438 0.0 \n",
"3 -0.307013 -0.424073 -0.490224 -2.159646 0.0 \n",
"4 -1.498407 -0.424073 0.138242 -0.765508 0.0 \n",
"5 1.256887 1.336741 -0.490224 -0.765508 0.0 \n",
"6 -0.506627 -0.424073 -0.490224 -0.068438 0.0 \n",
"7 -1.690161 -0.424073 -0.490224 0.628631 0.0 \n",
"8 -0.030384 -0.424073 -0.490224 0.628631 0.0 \n",
"9 -0.167127 3.349099 1.395175 -0.068438 0.0 \n",
"\n",
" Smoking_Yes AlcoholDrinking_Yes Stroke_Yes DiffWalking_Yes Sex_Male \\\n",
"0 0.0 0.0 0.0 1.0 0.0 \n",
"1 0.0 1.0 0.0 0.0 1.0 \n",
"2 0.0 0.0 0.0 0.0 0.0 \n",
"3 1.0 0.0 0.0 1.0 0.0 \n",
"4 0.0 0.0 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 1.0 1.0 \n",
"6 0.0 0.0 0.0 0.0 1.0 \n",
"7 0.0 0.0 0.0 0.0 0.0 \n",
"8 1.0 1.0 0.0 0.0 0.0 \n",
"9 0.0 0.0 0.0 1.0 1.0 \n",
"\n",
" ... Diabetic_Yes Diabetic_Yes (during pregnancy) PhysicalActivity_Yes \\\n",
"0 ... 1.0 0.0 0.0 \n",
"1 ... 0.0 0.0 1.0 \n",
"2 ... 0.0 0.0 0.0 \n",
"3 ... 1.0 0.0 0.0 \n",
"4 ... 0.0 0.0 1.0 \n",
"5 ... 1.0 0.0 0.0 \n",
"6 ... 0.0 0.0 1.0 \n",
"7 ... 0.0 0.0 1.0 \n",
"8 ... 0.0 0.0 1.0 \n",
"9 ... 0.0 0.0 0.0 \n",
"\n",
" GenHealth_Fair GenHealth_Good GenHealth_Poor GenHealth_Very good \\\n",
"0 0.0 0.0 1.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 \n",
"2 0.0 1.0 0.0 0.0 \n",
"3 1.0 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 1.0 \n",
"5 1.0 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 1.0 \n",
"7 0.0 0.0 0.0 1.0 \n",
"8 0.0 0.0 0.0 1.0 \n",
"9 1.0 0.0 0.0 0.0 \n",
"\n",
" Asthma_Yes KidneyDisease_Yes SkinCancer_Yes \n",
"0 1.0 0.0 1.0 \n",
"1 1.0 1.0 0.0 \n",
"2 0.0 0.0 1.0 \n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 1.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"7 0.0 0.0 0.0 \n",
"8 0.0 0.0 0.0 \n",
"9 0.0 0.0 0.0 \n",
"\n",
"[10 rows x 38 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Применение конвейера\n",
"preprocessing_result = pipeline_end.fit_transform(X_df_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Обучение моделей"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Модель: Random Forest\n",
"\tmean_score: 1.0\n",
"\tstd_dev: 0.0\n",
"\n",
"Модель: Linear Regression\n",
"\tmean_score: 1.0\n",
"\tstd_dev: 0.0\n",
"\n",
"Модель: Gradient Boosting\n",
"\tmean_score: 0.9999999324559854\n",
"\tstd_dev: 1.916515351322297e-08\n",
"\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"\n",
"# Обучить модели\n",
"def train_models(X, y, models):\n",
" results = {}\n",
" \n",
" for model_name, model in models.items():\n",
" # Создание конвейера для текущей модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"model\", model)\n",
" ]\n",
" )\n",
" \n",
" # Обучаем модель и вычисляем кросс-валидацию\n",
" scores = cross_val_score(model_pipeline, X, y, cv=5) # 5-кратная кросс-валидация\n",
" \n",
" # Вычисление метрик для текущей модели\n",
" metrics_dict = {\n",
" \"mean_score\": scores.mean(),\n",
" \"std_dev\": scores.std()\n",
" }\n",
" \n",
" # Сохранениерезультатов\n",
" results[model_name] = metrics_dict\n",
" \n",
" return results\n",
"\n",
"\n",
"# Выбранные модели для регрессии\n",
"models_regression = {\n",
" \"Random Forest\": RandomForestRegressor(),\n",
" \"Linear Regression\": LinearRegression(),\n",
" \"Gradient Boosting\": GradientBoostingRegressor(),\n",
"}\n",
"\n",
"results = train_models(X_df_train, y_df_train, models_regression)\n",
"\n",
"# Вывод результатов\n",
"for model_name, metrics_dict in results.items():\n",
" print(f\"Модель: {model_name}\")\n",
" for metric_name, value in metrics_dict.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Среднее значение и стандартное отклонение:**\n",
"\n",
"1. **Случайный лес (Random Forest)**:\n",
" - Метрики:\n",
" * Средний балл: 1.0\n",
" * Стандартное отклонение: 0.0\n",
" - ***Вывод***: модель случайного леса продемонстрировала идеальный результат с точностью 1.0 и без колебаний в результатах (ноль стандартного отклонения). Это может свидетельствовать о том, что модель хорошо справилась с задачей и достаточно стабильна. Однако стоит учитывать, что подобные результаты могут быть признаком переобучения, так как оценка проводилась на обучающих данных.\n",
"\n",
"1. **Линейная регрессия (Linear Regression)**:\n",
" - Метрики:\n",
" * Средний балл: 1.0\n",
" * Стандартное отклонение: 0.0\n",
" - ***Вывод***: линейная регрессия также показала идеальный результат с точностью 1.0 и нулевым отклонением. Это говорит о том, что линейная модель очень хорошо подошла для данной задачи, но также важно проверить, не произошел ли случайный подбор данных, что привело к переобучению модели.\n",
"\n",
"1. **Градиентный бустинг (Gradient Boosting)**:\n",
" - Метрики:\n",
" * Средний балл: 0.999\n",
" * Стандартное отклонение: 0.0\n",
" - ***Вывод***: Градиентный бустинг показал практически идеальный результат, также с нулевым стандартным отклонением. Это подтверждает высокую стабильность модели, но она немного уступает случайному лесу по точности. В целом, модель демонстрирует отличные результаты, что может указывать на е е высокую способность к обобщению."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Расчет метрик"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Модель: Random Forest\n",
"\tMAE_train: 0.0\n",
"\tMAE_test: 0.0\n",
"\tMSE_train: 0.0\n",
"\tMSE_test: 0.0\n",
"\tR2_train: 1.0\n",
"\tR2_test: 1.0\n",
"\tSTD_train: 0.0\n",
"\tSTD_test: 0.0\n",
"\n",
"Модель: Linear Regression\n",
"\tMAE_train: 1.194371035153155e-14\n",
"\tMAE_test: 1.1909445826766327e-14\n",
"\tMSE_train: 1.901081790225907e-28\n",
"\tMSE_test: 1.8951168132152725e-28\n",
"\tR2_train: 1.0\n",
"\tR2_test: 1.0\n",
"\tSTD_train: 9.090236366489451e-15\n",
"\tSTD_test: 9.090299369484082e-15\n",
"\n",
"Модель: Gradient Boosting\n",
"\tMAE_train: 0.00030786687422158955\n",
"\tMAE_test: 0.00030731279564540775\n",
"\tMSE_train: 4.381537207145074e-06\n",
"\tMSE_test: 4.342684206551716e-06\n",
"\tR2_train: 0.9999999306897712\n",
"\tR2_test: 0.9999999313016945\n",
"\tSTD_train: 0.0020932121744211872\n",
"\tSTD_test: 0.0020839106254309228\n",
"\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"from sklearn import metrics\n",
"\n",
"\n",
"# Оценка качества различных моделей на основе метрик\n",
"def evaluate_models(models,\n",
" pipeline_end, \n",
" X_train, y_train, \n",
" X_test, y_test):\n",
" results = {}\n",
" \n",
" for model_name, model in models.items():\n",
" # Создание конвейера для текущей модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", model),\n",
" ]\n",
" )\n",
" \n",
" # Обучение текущей модели\n",
" model_pipeline.fit(X_train, y_train)\n",
"\n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
"\n",
" # Вычисление метрик для текущей модели\n",
" metrics_dict = {\n",
" \"MAE_train\": metrics.mean_absolute_error(y_train, y_train_predict),\n",
" \"MAE_test\": metrics.mean_absolute_error(y_test, y_test_predict),\n",
" \"MSE_train\": metrics.mean_squared_error(y_train, y_train_predict),\n",
" \"MSE_test\": metrics.mean_squared_error(y_test, y_test_predict),\n",
" \"R2_train\": metrics.r2_score(y_train, y_train_predict),\n",
" \"R2_test\": metrics.r2_score(y_test, y_test_predict),\n",
" \"STD_train\": np.std(y_train - y_train_predict),\n",
" \"STD_test\": np.std(y_test - y_test_predict),\n",
" }\n",
"\n",
" # Сохранение результатов\n",
" results[model_name] = metrics_dict\n",
" \n",
" return results\n",
"\n",
"\n",
"y_train = np.ravel(y_df_train) \n",
"y_test = np.ravel(y_df_test) \n",
"\n",
"results = evaluate_models(models_regression,\n",
" pipeline_end,\n",
" X_df_train, y_train,\n",
" X_df_test, y_test)\n",
"\n",
"# Вывод результатов\n",
"for model_name, metrics_dict in results.items():\n",
" print(f\"Модель: {model_name}\")\n",
" for metric_name, value in metrics_dict.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Результаты:**\n",
"\n",
"1. **Случайный лес (Random Forest)**:\n",
" - Метрики: \n",
" * MAE (обучение): 0.0\n",
" * MAE (тест): 0.0\n",
" * MSE (обучение): 0.0\n",
" * MSE (тест): 0.0\n",
" * R² (обучение): 1.0\n",
" * R² (тест): 1.0\n",
" * STD (обучение): 0.0\n",
" * STD (тест): 0.0\n",
" - ***Вывод***: модель случайного леса продемонстрировала абсолютно идеальные результаты как на обучающих, так и на тестовых данных, с нулевыми значениями ошибок и максимально возможным значением R². Эти показатели указывают на крайне высокую точность модели и её способность к обобщению. Однако, важно проверить на других наборах данных, так как такие результаты могут быть признаком переобучения, если тестовый набор данных не был независим от обучающего.\n",
"\n",
"1. **Линейная регрессия (Linear Regression)**:\n",
" - Метрики: \n",
" * MAE (обучение): 1.19e-14\n",
" * MAE (тест): 1.19e-14\n",
" * MSE (обучение): 1.90e-28\n",
" * MSE (тест): 1.86e-28\n",
" * R² (обучение): 1.0\n",
" * R² (тест): 1.0\n",
" * STD (обучение): 9.09e-15\n",
" * STD (тест): 9.09e-15\n",
" - ***Вывод***: линейная регрессия также показала выдающиеся результаты с нулевыми ошибками и максимальным R², что может свидетельствовать о её идеальной подгонке под данные. Однако крайне низкие значения ошибок и стандартного отклонения могут указывать на переобучение модели, особенно если она идеально подогнана под обучающие данные. Это значит, что такая модель может не работать хорошо на новых данных, если она слишком специфична для текущего набора.\n",
"\n",
"1. **Градиентный бустинг (Gradient Boosting)**:\n",
" - Метрики: \n",
" * MAE (обучение): 0.0\n",
" * MAE (тест): 0.0\n",
" * MSE (обучение): 4.38e-06\n",
" * MSE (тест): 4.34e-06\n",
" * R² (обучение): 1.0\n",
" * R² (тест): 1.0\n",
" * STD (обучение): 0.002\n",
" * STD (тест): 0.002\n",
" - ***Вывод***: градиентный бустинг показал отличные результаты, с минимальными ошибками и максимально возможным R², что указывает на высокую точность модели. Небольшое стандартное отклонение (около 0.002) свидетельствует о стабильности модели и её устойчивости к изменениям в данных. Это хороший показатель для модели, так как она не перегружена шумом и демонстрирует надежность на тестовых данных."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Подбор гиперпараметров"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
"[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 1.5min\n",
"[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 1.5min\n",
"[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 1.5min\n",
"[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 1.6min\n",
"[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 1.6min\n",
"[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 1.6min\n",
"[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 3.0min\n",
"[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 3.0min\n",
"[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 3.1min\n",
"[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 1.5min\n",
"[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 3.1min\n",
"[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 3.1min\n",
"[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 3.2min\n",
"[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 1.6min\n",
"[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 1.6min\n",
"[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 1.7min\n",
"[CV] END max_depth=None, min_samples_split=10, n_estimators=100; total time= 3.2min\n",
"[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 1.7min\n",
"[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 1.8min\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[11], line 24\u001b[0m\n\u001b[1;32m 19\u001b[0m grid_search \u001b[38;5;241m=\u001b[39m GridSearchCV(estimator\u001b[38;5;241m=\u001b[39mmodel, \n\u001b[1;32m 20\u001b[0m param_grid\u001b[38;5;241m=\u001b[39mparam_grid,\n\u001b[1;32m 21\u001b[0m scoring\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mneg_mean_squared_error\u001b[39m\u001b[38;5;124m'\u001b[39m, cv\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, n_jobs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 23\u001b[0m \u001b[38;5;66;03m# Обучение модели на тренировочных данных\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m \u001b[43mgrid_search\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train_processing_result\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# Результаты подбора гиперпараметров\u001b[39;00m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mЛу чшие параметры:\u001b[39m\u001b[38;5;124m\"\u001b[39m, grid_search\u001b[38;5;241m.\u001b[39mbest_params_)\n",
"File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1466\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[1;32m 1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m 1469\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 1470\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1471\u001b[0m )\n\u001b[1;32m 1472\u001b[0m ):\n\u001b[0;32m-> 1473\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/model_selection/_search.py:1019\u001b[0m, in \u001b[0;36mBaseSearchCV.fit\u001b[0;34m(self, X, y, **params)\u001b[0m\n\u001b[1;32m 1013\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_results(\n\u001b[1;32m 1014\u001b[0m all_candidate_params, n_splits, all_out, all_more_results\n\u001b[1;32m 1015\u001b[0m )\n\u001b[1;32m 1017\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m results\n\u001b[0;32m-> 1019\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_search\u001b[49m\u001b[43m(\u001b[49m\u001b[43mevaluate_candidates\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1021\u001b[0m \u001b[38;5;66;03m# multimetric is determined here because in the case of a callable\u001b[39;00m\n\u001b[1;32m 1022\u001b[0m \u001b[38;5;66;03m# self.scoring the return type is only known after calling\u001b[39;00m\n\u001b[1;32m 1023\u001b[0m first_test_score \u001b[38;5;241m=\u001b[39m all_out[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_scores\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
"File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/model_selection/_search.py:1573\u001b[0m, in \u001b[0;36mGridSearchCV._run_search\u001b[0;34m(self, evaluate_candidates)\u001b[0m\n\u001b[1;32m 1571\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_run_search\u001b[39m(\u001b[38;5;28mself\u001b[39m, evaluate_candidates):\n\u001b[1;32m 1572\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Search all candidates in param_grid\"\"\"\u001b[39;00m\n\u001b[0;32m-> 1573\u001b[0m \u001b[43mevaluate_candidates\u001b[49m\u001b[43m(\u001b[49m\u001b[43mParameterGrid\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparam_grid\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/model_selection/_search.py:965\u001b[0m, in \u001b[0;36mBaseSearchCV.fit.<locals>.evaluate_candidates\u001b[0;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[1;32m 957\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 958\u001b[0m \u001b[38;5;28mprint\u001b[39m(\n\u001b[1;32m 959\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFitting \u001b[39m\u001b[38;5;132;01m{0}\u001b[39;00m\u001b[38;5;124m folds for each of \u001b[39m\u001b[38;5;132;01m{1}\u001b[39;00m\u001b[38;5;124m candidates,\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 960\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m totalling \u001b[39m\u001b[38;5;132;01m{2}\u001b[39;00m\u001b[38;5;124m fits\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 961\u001b[0m n_splits, n_candidates, n_candidates \u001b[38;5;241m*\u001b[39m n_splits\n\u001b[1;32m 962\u001b[0m )\n\u001b[1;32m 963\u001b[0m )\n\u001b[0;32m--> 965\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mparallel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 966\u001b[0m \u001b[43m \u001b[49m\u001b[43mdelayed\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_fit_and_score\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 967\u001b[0m \u001b[43m \u001b[49m\u001b[43mclone\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbase_estimator\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 968\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 969\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 970\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 971\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 972\u001b[0m \u001b[43m \u001b[49m\u001b[43mparameters\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparameters\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 973\u001b[0m \u001b[43m \u001b[49m\u001b[43msplit_progress\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43msplit_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_splits\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 974\u001b[0m \u001b[43m \u001b[49m\u001b[43mcandidate_progress\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcand_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_candidates\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 975\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfit_and_score_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 976\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 977\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mcand_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparameters\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43msplit_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mproduct\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 978\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcandidate_params\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u0
"File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/utils/parallel.py:74\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 69\u001b[0m config \u001b[38;5;241m=\u001b[39m get_config()\n\u001b[1;32m 70\u001b[0m iterable_with_config \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 71\u001b[0m (_with_config(delayed_func, config), args, kwargs)\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m delayed_func, args, kwargs \u001b[38;5;129;01min\u001b[39;00m iterable\n\u001b[1;32m 73\u001b[0m )\n\u001b[0;32m---> 74\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43miterable_with_config\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/joblib/parallel.py:2007\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 2001\u001b[0m \u001b[38;5;66;03m# The first item from the output is blank, but it makes the interpreter\u001b[39;00m\n\u001b[1;32m 2002\u001b[0m \u001b[38;5;66;03m# progress until it enters the Try/Except block of the generator and\u001b[39;00m\n\u001b[1;32m 2003\u001b[0m \u001b[38;5;66;03m# reaches the first `yield` statement. This starts the asynchronous\u001b[39;00m\n\u001b[1;32m 2004\u001b[0m \u001b[38;5;66;03m# dispatch of the tasks to the workers.\u001b[39;00m\n\u001b[1;32m 2005\u001b[0m \u001b[38;5;28mnext\u001b[39m(output)\n\u001b[0;32m-> 2007\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_generator \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/joblib/parallel.py:1650\u001b[0m, in \u001b[0;36mParallel._get_outputs\u001b[0;34m(self, iterator, pre_dispatch)\u001b[0m\n\u001b[1;32m 1647\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m\n\u001b[1;32m 1649\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backend\u001b[38;5;241m.\u001b[39mretrieval_context():\n\u001b[0;32m-> 1650\u001b[0m \u001b[38;5;28;01myield from\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_retrieve()\n\u001b[1;32m 1652\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mGeneratorExit\u001b[39;00m:\n\u001b[1;32m 1653\u001b[0m \u001b[38;5;66;03m# The generator has been garbage collected before being fully\u001b[39;00m\n\u001b[1;32m 1654\u001b[0m \u001b[38;5;66;03m# consumed. This aborts the remaining tasks if possible and warn\u001b[39;00m\n\u001b[1;32m 1655\u001b[0m \u001b[38;5;66;03m# the user if necessary.\u001b[39;00m\n\u001b[1;32m 1656\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
"File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/joblib/parallel.py:1762\u001b[0m, in \u001b[0;36mParallel._retrieve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1757\u001b[0m \u001b[38;5;66;03m# If the next job is not ready for retrieval yet, we just wait for\u001b[39;00m\n\u001b[1;32m 1758\u001b[0m \u001b[38;5;66;03m# async callbacks to progress.\u001b[39;00m\n\u001b[1;32m 1759\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ((\u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jobs) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m 1760\u001b[0m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jobs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mget_status(\n\u001b[1;32m 1761\u001b[0m timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtimeout) \u001b[38;5;241m==\u001b[39m TASK_PENDING)):\n\u001b[0;32m-> 1762\u001b[0m \u001b[43mtime\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msleep\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0.01\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1763\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[1;32m 1765\u001b[0m \u001b[38;5;66;03m# We need to be careful: the job list can be filling up as\u001b[39;00m\n\u001b[1;32m 1766\u001b[0m \u001b[38;5;66;03m# we empty it and Python list are not thread-safe by\u001b[39;00m\n\u001b[1;32m 1767\u001b[0m \u001b[38;5;66;03m# default hence the use of the lock\u001b[39;00m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"\n",
"# Применение конвейера к данным\n",
"X_train_processing_result = pipeline_end.fit_transform(X_df_train)\n",
"X_test_processing_result = pipeline_end.transform(X_df_test)\n",
"\n",
"# Создание и настройка модели случайного леса\n",
"model = RandomForestRegressor()\n",
"\n",
"# Установка параметров для поиска по сетке\n",
"param_grid = {\n",
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке\n",
"grid_search = GridSearchCV(estimator=model, \n",
" param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Результаты подбора гиперпараметров\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"# Меняем знак, так как берем отрицательное значение среднеквадратичной ошибки\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Сравнение наборов гиперпараметров"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Установка параметров для поиска по сетке для старых значений\n",
"old_param_grid = {\n",
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке для старых параметров\n",
"old_grid_search = GridSearchCV(estimator=model, \n",
" param_grid=old_param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"old_grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Результаты подбора для старых параметров\n",
"old_best_params = old_grid_search.best_params_\n",
"# Меняем знак, так как берем отрицательное значение MSE\n",
"old_best_mse = -old_grid_search.best_score_\n",
"\n",
"\n",
"# Установка параметров для поиска по сетке для новых значений\n",
"new_param_grid = {\n",
" 'n_estimators': [50],\n",
" 'max_depth': [5],\n",
" 'min_samples_split': [10]\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке для новых параметров\n",
"new_grid_search = GridSearchCV(estimator=model, \n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"new_grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Результаты подбора для новых параметров\n",
"new_best_params = new_grid_search.best_params_\n",
"# Меняем знак, так как берем отрицательное значение MSE\n",
"new_best_mse = -new_grid_search.best_score_\n",
"\n",
"\n",
"# Обучение модели с лучшими параметрами для новых значений\n",
"model_best = RandomForestRegressor(**new_best_params)\n",
"model_best.fit(X_train_processing_result, y_train)\n",
"\n",
"# Прогнозирование на тестовой выборке\n",
"y_pred = model_best.predict(X_test_processing_result)\n",
"\n",
"# Оценка производительности модели\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"\n",
"# Вывод результатов\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nН о вые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)\n",
"\n",
"# Обучение модели с лучшими параметрами для старых значений\n",
"model_old = RandomForestRegressor(**old_best_params)\n",
"model_old.fit(X_train_processing_result, y_train)\n",
"\n",
"# Прогнозирование на тестовой выборке для старых параметров\n",
"y_pred_old = model_old.predict(X_test_processing_result)\n",
"\n",
"# Визуализация ошибок\n",
"plt.figure(figsize=(10, 5))\n",
"plt.plot(y_test, label='Реальные значения', marker='o', linestyle='-', color='black')\n",
"plt.plot(y_pred_old, label='Предсказанные значения (старые параметры)', marker='x', linestyle='--', color='blue')\n",
"plt.plot(y_pred, label='Предсказанные значения (новые параметры)', marker='s', linestyle='--', color='orange')\n",
"plt.xlabel('Объекты')\n",
"plt.ylabel('Значения')\n",
"plt.title('Сравнение реальных и предсказанных значений')\n",
"plt.legend()\n",
"plt.show()"
]
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная...
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная.
— Возьмём?
— Конечно!
Она сзади села и в спальник забралась.
Водила сразу:
— Ну–ка, порули пока, я сейчас... — и к ней в спальник.
Через 10 минут выбрался, сел вперёд, закурил:
— Ай, хороша девка!
Напарник:
— Чё, правда хороша? Ну–ка, порули пока... — и к ней в спальник.
Через 10 минут девка вылезает, садится с водилой, закуривает:
— Ай, хорош у тебя напарник!
Водила:
— Чё правда так хорош? Ну–ка, порули...
2024-12-14 00:24:39 +04:00
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-12-14 11:55:47 +04:00
"version": "3.12.8"
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная...
Едут дальнобойщик с напарником. Смотрят — девка голосует симпатичная.
— Возьмём?
— Конечно!
Она сзади села и в спальник забралась.
Водила сразу:
— Ну–ка, порули пока, я сейчас... — и к ней в спальник.
Через 10 минут выбрался, сел вперёд, закурил:
— Ай, хороша девка!
Напарник:
— Чё, правда хороша? Ну–ка, порули пока... — и к ней в спальник.
Через 10 минут девка вылезает, садится с водилой, закуривает:
— Ай, хорош у тебя напарник!
Водила:
— Чё правда так хорош? Ну–ка, порули...
2024-12-14 00:24:39 +04:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}