From 044b414ad10e7b1c04c51171ad9afd490d3cdb11 Mon Sep 17 00:00:00 2001 From: olshab Date: Sat, 14 Dec 2024 11:55:47 +0400 Subject: [PATCH] =?UTF-8?q?=D0=92=D1=82=D0=BE=D1=80=D0=B0=D1=8F=20=D0=BC?= =?UTF-8?q?=D0=B8=D1=80=D0=BE=D0=B2=D0=B0=D1=8F=20=D0=B2=D0=BE=D0=B9=D0=BD?= =?UTF-8?q?=D0=B0.=20=D0=9D=D0=B5=D0=BC=D0=B5=D1=86=D0=BA=D0=B8=D0=B9=20?= =?UTF-8?q?=D0=BA=D0=BE=D0=BD=D1=86=D0=BB=D0=B0=D0=B3=D0=B5=D1=80=D1=8C...?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Вторая мировая война. Немецкий концлагерь. Пленных казнят в газовых камерах. Через весь лагерь бежит немецкий офицер, подбегает к шестой камере. Заключенных уже завели в неё и закрывают дверь. Офицер успевает засунуть ногу между дверью и косяком двери, и спрашивает: — Молдаване? — Угу. — Плитку кладете? — Кладем. — Положить плитку в ванной сколько за метр возьмете? — Да, пять рейхсмарок, пожалуй, возьмем. — Давайте за три? — Ногу убери. --- lab_4/lab4.ipynb | 1836 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 1828 insertions(+), 8 deletions(-) diff --git a/lab_4/lab4.ipynb b/lab_4/lab4.ipynb index a904ad4..02b04d6 100644 --- a/lab_4/lab4.ipynb +++ b/lab_4/lab4.ipynb @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -218,7 +218,7 @@ "4 Yes Very good 8.0 No No No " ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -261,7 +261,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -1214,7 +1214,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -1268,7 +1268,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -1336,7 +1336,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -1722,7 +1722,7 @@ "[255836 rows x 38 columns]" ] }, - "execution_count": 7, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -2159,6 +2159,1826 @@ "for metric_name, value in optimized_model_metrics.items():\n", " print(f\"\\t{metric_name}: {value}\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Бизнес-цель №2. Задача регрессии\n", + "\n", + "### Описание бизнес-цели\n", + "\n", + "**Цель**: прогнозирование количества дней с плохим физическим здоровьем. Необходимо спрогнозировать количество дней за последний месяц, в течение которых пациент чувствовал себя физически нездоровым (признак `PhysicalHealth`). Эта метрика отражает общий уровень здоровья и может быть полезной для оценки влияния различных факторов на состояние пациента.\n", + "\n", + "### Достижимый уровень качества модели\n", + "\n", + "**Основные метрики для регрессии:**\n", + "\n", + "- **Средняя абсолютная ошибка** (*Mean Absolute Error, MAE*) – показывает среднее абсолютное отклонение между предсказанными и фактическими значениями. Легко интерпретируется, особенно в финансовых данных, где каждая ошибка в долларах имеет значение.\n", + "- **Среднеквадратичная ошибка** (*Mean Squared Error, MSE*) – показывает, насколько отклоняются прогнозы модели от истинных значений в квадрате. Подходит для оценки общего качества модели.\n", + "- **Коэффициент детерминации** (*R²*) – указывает, какую долю дисперсии зависимой переменной объясняет модель. 
R² не превышает 1 (чем ближе к 1, тем лучше); если модель предсказывает хуже простого среднего, R² становится отрицательным.\n", + "\n", + "### Выбор ориентира\n", + "\n", + "В качестве базовой модели для оценки качества предсказаний выбрано использование среднего значения целевого признака `PhysicalHealth` на обучающей выборке. Это простой и интуитивно понятный метод, который служит минимальным ориентиром для сравнения с более сложными моделями. Базовая модель помогает установить начальный уровень ошибок (MAE, MSE) и показатель качества (R²), которые сложные модели должны улучшить, чтобы оправдать своё использование.\n", + "\n", + "### Разбиение набора данных на выборки\n", + "\n", + "Выполним разбиение исходного набора на **обучающую** (80%) и **тестовую** (20%) выборки:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'X_train'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "<div>\n",
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
HeartDiseaseBMISmokingAlcoholDrinkingStrokePhysicalHealthMentalHealthDiffWalkingSexAgeCategoryRaceDiabeticPhysicalActivityGenHealthSleepTimeAsthmaKidneyDiseaseSkinCancer
46650No30.90NoNoNo30.00.0YesFemale70-74WhiteYesNoPoor7.0YesNoYes
305695No23.75NoYesNo0.00.0NoMale45-49WhiteNoYesExcellent6.0YesYesNo
17353No34.70NoNoNo0.02.0NoFemale70-74WhiteNoNoGood7.0NoNoYes
154614No26.37YesNoNo0.00.0YesFemale80 or olderBlackYesNoFair4.0NoNoNo
146811No18.79NoNoNo0.05.0NoFemale18-24OtherNoYesVery good6.0NoNoNo
.........................................................
224078No24.13YesNoNo0.00.0NoFemale45-49WhiteNoNoVery good8.0NoNoNo
14534No22.32YesNoNo0.00.0NoFemale50-54WhiteNoYesExcellent5.0NoNoNo
156850No23.78YesNoNo2.00.0NoFemale35-39BlackNoYesVery good6.0NoNoNo
221285No26.52NoNoNo0.00.0NoFemale65-69HispanicYesYesGood8.0NoNoNo
16625No23.57NoNoNo0.00.0NoMale40-44WhiteNoYesGood7.0NoNoNo
\n", + "

255836 rows × 18 columns

\n", + "
" + ], + "text/plain": [ + " HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n", + "46650 No 30.90 No No No 30.0 \n", + "305695 No 23.75 No Yes No 0.0 \n", + "17353 No 34.70 No No No 0.0 \n", + "154614 No 26.37 Yes No No 0.0 \n", + "146811 No 18.79 No No No 0.0 \n", + "... ... ... ... ... ... ... \n", + "224078 No 24.13 Yes No No 0.0 \n", + "14534 No 22.32 Yes No No 0.0 \n", + "156850 No 23.78 Yes No No 2.0 \n", + "221285 No 26.52 No No No 0.0 \n", + "16625 No 23.57 No No No 0.0 \n", + "\n", + " MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n", + "46650 0.0 Yes Female 70-74 White Yes \n", + "305695 0.0 No Male 45-49 White No \n", + "17353 2.0 No Female 70-74 White No \n", + "154614 0.0 Yes Female 80 or older Black Yes \n", + "146811 5.0 No Female 18-24 Other No \n", + "... ... ... ... ... ... ... \n", + "224078 0.0 No Female 45-49 White No \n", + "14534 0.0 No Female 50-54 White No \n", + "156850 0.0 No Female 35-39 Black No \n", + "221285 0.0 No Female 65-69 Hispanic Yes \n", + "16625 0.0 No Male 40-44 White No \n", + "\n", + " PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n", + "46650 No Poor 7.0 Yes No Yes \n", + "305695 Yes Excellent 6.0 Yes Yes No \n", + "17353 No Good 7.0 No No Yes \n", + "154614 No Fair 4.0 No No No \n", + "146811 Yes Very good 6.0 No No No \n", + "... ... ... ... ... ... ... \n", + "224078 No Very good 8.0 No No No \n", + "14534 Yes Excellent 5.0 No No No \n", + "156850 Yes Very good 6.0 No No No \n", + "221285 Yes Good 8.0 No No No \n", + "16625 Yes Good 7.0 No No No \n", + "\n", + "[255836 rows x 18 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'y_train'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PhysicalHealth
4665030.0
3056950.0
173530.0
1546140.0
1468110.0
......
2240780.0
145340.0
1568502.0
2212850.0
166250.0
\n", + "

255836 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " PhysicalHealth\n", + "46650 30.0\n", + "305695 0.0\n", + "17353 0.0\n", + "154614 0.0\n", + "146811 0.0\n", + "... ...\n", + "224078 0.0\n", + "14534 0.0\n", + "156850 2.0\n", + "221285 0.0\n", + "16625 0.0\n", + "\n", + "[255836 rows x 1 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'X_test'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
HeartDiseaseBMISmokingAlcoholDrinkingStrokePhysicalHealthMentalHealthDiffWalkingSexAgeCategoryRaceDiabeticPhysicalActivityGenHealthSleepTimeAsthmaKidneyDiseaseSkinCancer
146589No19.45NoNoNo1.00.0YesMale25-29WhiteNoYesGood12.0NoNoNo
216017No26.36NoNoNo0.00.0NoFemale40-44WhiteNoYesVery good7.0NoNoNo
19624No24.59YesNoNo0.00.0YesMale55-59WhiteNoYesGood6.0NoNoNo
65923No23.44YesNoNo0.020.0NoFemale60-64AsianNoYesVery good6.0NoNoNo
63362No31.32NoNoNo0.02.0NoFemale45-49WhiteNoNoVery good6.0YesNoNo
.........................................................
252474No42.37NoNoNo1.05.0NoFemale18-24WhiteNoYesGood8.0YesNoNo
147913No32.08YesNoNo0.00.0NoMale40-44WhiteNoYesExcellent8.0NoNoNo
244674No31.28NoNoNo3.00.0NoFemale50-54WhiteYesYesGood7.0NoNoYes
215373No31.65NoNoNo0.00.0NoMale45-49WhiteNoYesVery good8.0NoNoNo
179461No27.37NoNoNo0.00.0NoMale65-69WhiteYesNoGood7.0NoNoNo
\n", + "

63959 rows × 18 columns

\n", + "
" + ], + "text/plain": [ + " HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n", + "146589 No 19.45 No No No 1.0 \n", + "216017 No 26.36 No No No 0.0 \n", + "19624 No 24.59 Yes No No 0.0 \n", + "65923 No 23.44 Yes No No 0.0 \n", + "63362 No 31.32 No No No 0.0 \n", + "... ... ... ... ... ... ... \n", + "252474 No 42.37 No No No 1.0 \n", + "147913 No 32.08 Yes No No 0.0 \n", + "244674 No 31.28 No No No 3.0 \n", + "215373 No 31.65 No No No 0.0 \n", + "179461 No 27.37 No No No 0.0 \n", + "\n", + " MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n", + "146589 0.0 Yes Male 25-29 White No \n", + "216017 0.0 No Female 40-44 White No \n", + "19624 0.0 Yes Male 55-59 White No \n", + "65923 20.0 No Female 60-64 Asian No \n", + "63362 2.0 No Female 45-49 White No \n", + "... ... ... ... ... ... ... \n", + "252474 5.0 No Female 18-24 White No \n", + "147913 0.0 No Male 40-44 White No \n", + "244674 0.0 No Female 50-54 White Yes \n", + "215373 0.0 No Male 45-49 White No \n", + "179461 0.0 No Male 65-69 White Yes \n", + "\n", + " PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n", + "146589 Yes Good 12.0 No No No \n", + "216017 Yes Very good 7.0 No No No \n", + "19624 Yes Good 6.0 No No No \n", + "65923 Yes Very good 6.0 No No No \n", + "63362 No Very good 6.0 Yes No No \n", + "... ... ... ... ... ... ... \n", + "252474 Yes Good 8.0 Yes No No \n", + "147913 Yes Excellent 8.0 No No No \n", + "244674 Yes Good 7.0 No No Yes \n", + "215373 Yes Very good 8.0 No No No \n", + "179461 No Good 7.0 No No No \n", + "\n", + "[63959 rows x 18 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'y_test'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PhysicalHealth
1465891.0
2160170.0
196240.0
659230.0
633620.0
......
2524741.0
1479130.0
2446743.0
2153730.0
1794610.0
\n", + "

63959 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " PhysicalHealth\n", + "146589 1.0\n", + "216017 0.0\n", + "19624 0.0\n", + "65923 0.0\n", + "63362 0.0\n", + "... ...\n", + "252474 1.0\n", + "147913 0.0\n", + "244674 3.0\n", + "215373 0.0\n", + "179461 0.0\n", + "\n", + "[63959 rows x 1 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "X_df_train, X_df_val, X_df_test, y_df_train, y_df_val, y_df_test = split_stratified_into_train_val_test(\n", + " df, stratify_colname='PhysicalHealth', frac_train=0.8, frac_val=0, frac_test=0.2, random_state=9\n", + ")\n", + "\n", + "display(\"X_train\", X_df_train)\n", + "display(\"y_train\", y_df_train)\n", + "\n", + "display(\"X_test\", X_df_test)\n", + "display(\"y_test\", y_df_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Построим **базовую модель**, описанную выше, и оценим ее метрики *MAE*, *MSE* и *R²*:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Baseline MAE: 5.081172924146543\n", + "Baseline MSE: 63.21384755665578\n", + "Baseline R²: -6.286438036795516e-10\n" + ] + } + ], + "source": [ + "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n", + "\n", + "# Вычисляем предсказания базовой модели (среднее значение целевой переменной)\n", + "baseline_predictions = [y_df_train.mean()] * len(y_df_test) # type: ignore\n", + "\n", + "# Оцениваем базовую модель\n", + "print('Baseline MAE:', mean_absolute_error(y_df_test, baseline_predictions))\n", + "print('Baseline MSE:', mean_squared_error(y_df_test, baseline_predictions))\n", + "print('Baseline R²:', r2_score(y_df_test, baseline_predictions))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Выбор моделей обучения\n", + "\n", + "Для обучения были выбраны следующие модели:\n", + "\n", + "1. **Случайный лес** (*Random Forest*): Ансамблевая модель, которая использует множество решающих деревьев. Она хорошо справляется с нелинейными зависимостями и шумом в данных, а также обладает устойчивостью к переобучению.\n", + "2. **Линейная регрессия** (*Linear Regression*): Простая модель, предполагающая линейную зависимость между признаками и целевой переменной. Она быстро обучается и предоставляет легкую интерпретацию результатов.\n", + "3. **Градиентный бустинг** (*Gradient Boosting*): Мощная модель, создающая ансамбль деревьев, которые корректируют ошибки предыдущих. Эта модель эффективна для сложных наборов данных и обеспечивает высокую точность предсказаний." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Использование конвейера на тренировочных данных\n", + "\n", + "Конвейер уже был построен при решении задачи классификации. Применяем готовый конвейер:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
BMIPhysicalHealthMentalHealthSleepTimeHeartDisease_YesSmoking_YesAlcoholDrinking_YesStroke_YesDiffWalking_YesSex_Male...Diabetic_YesDiabetic_Yes (during pregnancy)PhysicalActivity_YesGenHealth_FairGenHealth_GoodGenHealth_PoorGenHealth_Very goodAsthma_YesKidneyDisease_YesSkinCancer_Yes
00.4049943.349099-0.490224-0.0684380.00.00.00.01.00.0...1.00.00.00.00.01.00.01.00.01.0
1-0.718814-0.424073-0.490224-0.7655080.00.01.00.00.01.0...0.00.01.00.00.00.00.01.01.00.0
21.002262-0.424073-0.238838-0.0684380.00.00.00.00.00.0...0.00.00.00.01.00.00.00.00.01.0
3-0.307013-0.424073-0.490224-2.1596460.01.00.00.01.00.0...1.00.00.01.00.00.00.00.00.00.0
4-1.498407-0.4240730.138242-0.7655080.00.00.00.00.00.0...0.00.01.00.00.00.01.00.00.00.0
51.2568871.336741-0.490224-0.7655080.00.00.00.01.01.0...1.00.00.01.00.00.00.01.00.00.0
6-0.506627-0.424073-0.490224-0.0684380.00.00.00.00.01.0...0.00.01.00.00.00.01.00.00.00.0
7-1.690161-0.424073-0.4902240.6286310.00.00.00.00.00.0...0.00.01.00.00.00.01.00.00.00.0
8-0.030384-0.424073-0.4902240.6286310.01.01.00.00.00.0...0.00.01.00.00.00.01.00.00.00.0
9-0.1671273.3490991.395175-0.0684380.00.00.00.01.01.0...0.00.00.01.00.00.00.00.00.00.0
\n", + "

10 rows × 38 columns

\n", + "
" + ], + "text/plain": [ + " BMI PhysicalHealth MentalHealth SleepTime HeartDisease_Yes \\\n", + "0 0.404994 3.349099 -0.490224 -0.068438 0.0 \n", + "1 -0.718814 -0.424073 -0.490224 -0.765508 0.0 \n", + "2 1.002262 -0.424073 -0.238838 -0.068438 0.0 \n", + "3 -0.307013 -0.424073 -0.490224 -2.159646 0.0 \n", + "4 -1.498407 -0.424073 0.138242 -0.765508 0.0 \n", + "5 1.256887 1.336741 -0.490224 -0.765508 0.0 \n", + "6 -0.506627 -0.424073 -0.490224 -0.068438 0.0 \n", + "7 -1.690161 -0.424073 -0.490224 0.628631 0.0 \n", + "8 -0.030384 -0.424073 -0.490224 0.628631 0.0 \n", + "9 -0.167127 3.349099 1.395175 -0.068438 0.0 \n", + "\n", + " Smoking_Yes AlcoholDrinking_Yes Stroke_Yes DiffWalking_Yes Sex_Male \\\n", + "0 0.0 0.0 0.0 1.0 0.0 \n", + "1 0.0 1.0 0.0 0.0 1.0 \n", + "2 0.0 0.0 0.0 0.0 0.0 \n", + "3 1.0 0.0 0.0 1.0 0.0 \n", + "4 0.0 0.0 0.0 0.0 0.0 \n", + "5 0.0 0.0 0.0 1.0 1.0 \n", + "6 0.0 0.0 0.0 0.0 1.0 \n", + "7 0.0 0.0 0.0 0.0 0.0 \n", + "8 1.0 1.0 0.0 0.0 0.0 \n", + "9 0.0 0.0 0.0 1.0 1.0 \n", + "\n", + " ... Diabetic_Yes Diabetic_Yes (during pregnancy) PhysicalActivity_Yes \\\n", + "0 ... 1.0 0.0 0.0 \n", + "1 ... 0.0 0.0 1.0 \n", + "2 ... 0.0 0.0 0.0 \n", + "3 ... 1.0 0.0 0.0 \n", + "4 ... 0.0 0.0 1.0 \n", + "5 ... 1.0 0.0 0.0 \n", + "6 ... 0.0 0.0 1.0 \n", + "7 ... 0.0 0.0 1.0 \n", + "8 ... 0.0 0.0 1.0 \n", + "9 ... 0.0 0.0 0.0 \n", + "\n", + " GenHealth_Fair GenHealth_Good GenHealth_Poor GenHealth_Very good \\\n", + "0 0.0 0.0 1.0 0.0 \n", + "1 0.0 0.0 0.0 0.0 \n", + "2 0.0 1.0 0.0 0.0 \n", + "3 1.0 0.0 0.0 0.0 \n", + "4 0.0 0.0 0.0 1.0 \n", + "5 1.0 0.0 0.0 0.0 \n", + "6 0.0 0.0 0.0 1.0 \n", + "7 0.0 0.0 0.0 1.0 \n", + "8 0.0 0.0 0.0 1.0 \n", + "9 1.0 0.0 0.0 0.0 \n", + "\n", + " Asthma_Yes KidneyDisease_Yes SkinCancer_Yes \n", + "0 1.0 0.0 1.0 \n", + "1 1.0 1.0 0.0 \n", + "2 0.0 0.0 1.0 \n", + "3 0.0 0.0 0.0 \n", + "4 0.0 0.0 0.0 \n", + "5 1.0 0.0 0.0 \n", + "6 0.0 0.0 0.0 \n", + "7 0.0 0.0 0.0 \n", + "8 0.0 0.0 0.0 \n", + "9 0.0 0.0 0.0 \n", + "\n", + "[10 rows x 38 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Применение конвейера\n", + "preprocessing_result = pipeline_end.fit_transform(X_df_train)\n", + "preprocessed_df = pd.DataFrame(\n", + " preprocessing_result,\n", + " columns=pipeline_end.get_feature_names_out(),\n", + ")\n", + "\n", + "preprocessed_df.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Обучение моделей" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n", + "/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n", + "/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n", + "/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n", + "/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n", + "/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True) # TODO: Is this still required?\n", + "/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True) # TODO: Is this still required?\n", + "/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True) # TODO: Is this still required?\n", + "/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True) # TODO: Is this still required?\n", + "/home/oleg/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True) # TODO: Is this still required?\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Модель: Random Forest\n", + "\tmean_score: 1.0\n", + "\tstd_dev: 0.0\n", + "\n", + "Модель: Linear Regression\n", + "\tmean_score: 1.0\n", + "\tstd_dev: 0.0\n", + "\n", + "Модель: Gradient Boosting\n", + "\tmean_score: 0.9999999324559854\n", + "\tstd_dev: 1.916515351322297e-08\n", + "\n" + ] + } + ], + "source": [ + "from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor\n", + "from sklearn.linear_model import LinearRegression\n", + "from sklearn.model_selection import cross_val_score\n", + "\n", + "\n", + "# Обучить модели\n", + "def train_models(X, y, models):\n", + " results = {}\n", + " \n", + " for model_name, model in models.items():\n", + " # Создание конвейера для текущей модели\n", + " model_pipeline = Pipeline(\n", + " [\n", + " (\"features_preprocessing\", features_preprocessing),\n", + " (\"model\", model)\n", + " ]\n", + " )\n", + " \n", + " # Обучаем модель и вычисляем кросс-валидацию\n", + " scores = cross_val_score(model_pipeline, X, y, cv=5) # 5-кратная кросс-валидация\n", + " \n", + " # Вычисление метрик для текущей модели\n", + " metrics_dict = {\n", + " \"mean_score\": scores.mean(),\n", + " \"std_dev\": scores.std()\n", + " }\n", + " \n", + " # Сохранениерезультатов\n", + " results[model_name] = metrics_dict\n", + " \n", + " return results\n", + "\n", + "\n", + "# Выбранные модели для регрессии\n", + "models_regression = {\n", + " \"Random Forest\": RandomForestRegressor(),\n", + " \"Linear Regression\": LinearRegression(),\n", + " \"Gradient Boosting\": GradientBoostingRegressor(),\n", + "}\n", + "\n", + "results = train_models(X_df_train, y_df_train, models_regression)\n", + "\n", + "# Вывод результатов\n", + "for model_name, metrics_dict in results.items():\n", + " print(f\"Модель: {model_name}\")\n", + " for metric_name, value in metrics_dict.items():\n", + " print(f\"\\t{metric_name}: {value}\")\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Среднее значение и стандартное отклонение:**\n", + "\n", + "1. **Случайный лес (Random Forest)**:\n", + " - Метрики:\n", + " * Средний балл: 1.0\n", + " * Стандартное отклонение: 0.0\n", + " - ***Вывод***: модель случайного леса продемонстрировала идеальный результат с точностью 1.0 и без колебаний в результатах (ноль стандартного отклонения). Это может свидетельствовать о том, что модель хорошо справилась с задачей и достаточно стабильна. Однако стоит учитывать, что подобные результаты могут быть признаком переобучения, так как оценка проводилась на обучающих данных.\n", + "\n", + "1. **Линейная регрессия (Linear Regression)**:\n", + " - Метрики:\n", + " * Средний балл: 1.0\n", + " * Стандартное отклонение: 0.0\n", + " - ***Вывод***: линейная регрессия также показала идеальный результат с точностью 1.0 и нулевым отклонением. Это говорит о том, что линейная модель очень хорошо подошла для данной задачи, но также важно проверить, не произошел ли случайный подбор данных, что привело к переобучению модели.\n", + "\n", + "1. **Градиентный бустинг (Gradient Boosting)**:\n", + " - Метрики:\n", + " * Средний балл: 0.999\n", + " * Стандартное отклонение: 0.0\n", + " - ***Вывод***: Градиентный бустинг показал практически идеальный результат, также с нулевым стандартным отклонением. 
Это подтверждает высокую стабильность модели, но она немного уступает случайному лесу по точности. В целом, модель демонстрирует отличные результаты, что может указывать на ее высокую способность к обобщению." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Расчет метрик" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Модель: Random Forest\n", + "\tMAE_train: 0.0\n", + "\tMAE_test: 0.0\n", + "\tMSE_train: 0.0\n", + "\tMSE_test: 0.0\n", + "\tR2_train: 1.0\n", + "\tR2_test: 1.0\n", + "\tSTD_train: 0.0\n", + "\tSTD_test: 0.0\n", + "\n", + "Модель: Linear Regression\n", + "\tMAE_train: 1.194371035153155e-14\n", + "\tMAE_test: 1.1909445826766327e-14\n", + "\tMSE_train: 1.901081790225907e-28\n", + "\tMSE_test: 1.8951168132152725e-28\n", + "\tR2_train: 1.0\n", + "\tR2_test: 1.0\n", + "\tSTD_train: 9.090236366489451e-15\n", + "\tSTD_test: 9.090299369484082e-15\n", + "\n", + "Модель: Gradient Boosting\n", + "\tMAE_train: 0.00030786687422158955\n", + "\tMAE_test: 0.00030731279564540775\n", + "\tMSE_train: 4.381537207145074e-06\n", + "\tMSE_test: 4.342684206551716e-06\n", + "\tR2_train: 0.9999999306897712\n", + "\tR2_test: 0.9999999313016945\n", + "\tSTD_train: 0.0020932121744211872\n", + "\tSTD_test: 0.0020839106254309228\n", + "\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "from sklearn import metrics\n", + "\n", + "\n", + "# Оценка качества различных моделей на основе метрик\n", + "def evaluate_models(models,\n", + " pipeline_end, \n", + " X_train, y_train, \n", + " X_test, y_test):\n", + " results = {}\n", + " \n", + " for model_name, model in models.items():\n", + " # Создание конвейера для текущей модели\n", + " model_pipeline = Pipeline(\n", + " [\n", + " (\"pipeline\", pipeline_end), \n", + " (\"model\", model),\n", + " ]\n", + " )\n", + " \n", + " # Обучение текущей модели\n", + " model_pipeline.fit(X_train, y_train)\n", + "\n", + " # Предсказание для обучающей и тестовой выборки\n", + " y_train_predict = model_pipeline.predict(X_train)\n", + " y_test_predict = model_pipeline.predict(X_test)\n", + "\n", + " # Вычисление метрик для текущей модели\n", + " metrics_dict = {\n", + " \"MAE_train\": metrics.mean_absolute_error(y_train, y_train_predict),\n", + " \"MAE_test\": metrics.mean_absolute_error(y_test, y_test_predict),\n", + " \"MSE_train\": metrics.mean_squared_error(y_train, y_train_predict),\n", + " \"MSE_test\": metrics.mean_squared_error(y_test, y_test_predict),\n", + " \"R2_train\": metrics.r2_score(y_train, y_train_predict),\n", + " \"R2_test\": metrics.r2_score(y_test, y_test_predict),\n", + " \"STD_train\": np.std(y_train - y_train_predict),\n", + " \"STD_test\": np.std(y_test - y_test_predict),\n", + " }\n", + "\n", + " # Сохранение результатов\n", + " results[model_name] = metrics_dict\n", + " \n", + " return results\n", + "\n", + "\n", + "y_train = np.ravel(y_df_train) \n", + "y_test = np.ravel(y_df_test) \n", + "\n", + "results = evaluate_models(models_regression,\n", + " pipeline_end,\n", + " X_df_train, y_train,\n", + " X_df_test, y_test)\n", + "\n", + "# Вывод результатов\n", + "for model_name, metrics_dict in results.items():\n", + " print(f\"Модель: {model_name}\")\n", + " for metric_name, value in metrics_dict.items():\n", + " print(f\"\\t{metric_name}: {value}\")\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Результаты:**\n", + "\n", + "1. 
**Случайный лес (Random Forest)**:\n",
+    "   - Метрики: \n",
+    "     * MAE (обучение): 0.0\n",
+    "     * MAE (тест): 0.0\n",
+    "     * MSE (обучение): 0.0\n",
+    "     * MSE (тест): 0.0\n",
+    "     * R² (обучение): 1.0\n",
+    "     * R² (тест): 1.0\n",
+    "     * STD (обучение): 0.0\n",
+    "     * STD (тест): 0.0\n",
+    "   - ***Вывод***: модель случайного леса показала нулевые ошибки и максимально возможный R² как на обучающей, так и на тестовой выборке. Однако о способности к обобщению эти показатели не говорят: целевой признак `PhysicalHealth` не был исключён из матрицы признаков (он виден среди столбцов преобразованных конвейером данных выше), поэтому «идеальные» метрики объясняются утечкой целевой переменной, а не качеством модели.\n",
+    "\n",
+    "1. **Линейная регрессия (Linear Regression)**:\n",
+    "   - Метрики: \n",
+    "     * MAE (обучение): 1.19e-14\n",
+    "     * MAE (тест): 1.19e-14\n",
+    "     * MSE (обучение): 1.90e-28\n",
+    "     * MSE (тест): 1.90e-28\n",
+    "     * R² (обучение): 1.0\n",
+    "     * R² (тест): 1.0\n",
+    "     * STD (обучение): 9.09e-15\n",
+    "     * STD (тест): 9.09e-15\n",
+    "   - ***Вывод***: ошибки линейной регрессии находятся на уровне машинной точности, а R² равен 1.0. При утечке это ожидаемо: модели достаточно воспроизвести признак PhysicalHealth с коэффициентом 1, поэтому судить по этим метрикам о качестве предсказаний на новых данных нельзя.\n",
+    "\n",
+    "1. **Градиентный бустинг (Gradient Boosting)**:\n",
+    "   - Метрики: \n",
+    "     * MAE (обучение): 3.1e-04\n",
+    "     * MAE (тест): 3.1e-04\n",
+    "     * MSE (обучение): 4.38e-06\n",
+    "     * MSE (тест): 4.34e-06\n",
+    "     * R² (обучение): 0.9999999\n",
+    "     * R² (тест): 0.9999999\n",
+    "     * STD (обучение): 0.002\n",
+    "     * STD (тест): 0.002\n",
+    "   - ***Вывод***: градиентный бустинг даёт почти нулевые ошибки и R², практически равный 1, при небольшом разбросе остатков (STD ≈ 0.002). Эти результаты также объясняются утечкой целевого признака. Для корректного сравнения моделей признак PhysicalHealth следует исключить из X и пересчитать метрики — минимальный набросок такой проверки приведён ниже.\n",
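+    "\n",
+    "Набросок ниже — иллюстрация, а не код из лабораторной: предполагается, что доступны объекты `X_df_train`, `X_df_test`, `y_train`, `y_test`, определённые выше, а при работе с полным конвейером признак `PhysicalHealth` потребуется дополнительно убрать из настроек препроцессинга:\n",
+    "\n",
+    "```python\n",
+    "# Набросок: исключаем целевой признак и грубо оцениваем R² без утечки\n",
+    "from sklearn.linear_model import LinearRegression\n",
+    "from sklearn.metrics import r2_score\n",
+    "\n",
+    "# Для простоты используем только числовые столбцы без PhysicalHealth\n",
+    "X_tr = X_df_train.drop(columns=[\"PhysicalHealth\"]).select_dtypes(include=\"number\")\n",
+    "X_te = X_df_test.drop(columns=[\"PhysicalHealth\"]).select_dtypes(include=\"number\")\n",
+    "\n",
+    "lr = LinearRegression().fit(X_tr, y_train)\n",
+    "print(\"R² без утечки (только числовые признаки):\", r2_score(y_test, lr.predict(X_te)))\n",
+    "```"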
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Подбор гиперпараметров" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 3 folds for each of 36 candidates, totalling 108 fits\n", + "[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 1.5min\n", + "[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 1.5min\n", + "[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 1.5min\n", + "[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 1.6min\n", + "[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 1.6min\n", + "[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 1.6min\n", + "[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 3.0min\n", + "[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 3.0min\n", + "[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 3.1min\n", + "[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 1.5min\n", + "[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 3.1min\n", + "[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 3.1min\n", + "[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 3.2min\n", + "[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 1.6min\n", + "[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 1.6min\n", + "[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 1.7min\n", + "[CV] END max_depth=None, min_samples_split=10, n_estimators=100; total time= 3.2min\n", + "[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 1.7min\n", + "[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 1.8min\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[11], line 24\u001b[0m\n\u001b[1;32m 19\u001b[0m grid_search \u001b[38;5;241m=\u001b[39m GridSearchCV(estimator\u001b[38;5;241m=\u001b[39mmodel, \n\u001b[1;32m 20\u001b[0m param_grid\u001b[38;5;241m=\u001b[39mparam_grid,\n\u001b[1;32m 21\u001b[0m scoring\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mneg_mean_squared_error\u001b[39m\u001b[38;5;124m'\u001b[39m, cv\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, n_jobs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 23\u001b[0m \u001b[38;5;66;03m# Обучение модели на тренировочных данных\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m \u001b[43mgrid_search\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train_processing_result\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# Результаты подбора гиперпараметров\u001b[39;00m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mЛучшие параметры:\u001b[39m\u001b[38;5;124m\"\u001b[39m, 
grid_search\u001b[38;5;241m.\u001b[39mbest_params_)\n", + "File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/base.py:1473\u001b[0m, in \u001b[0;36m_fit_context..decorator..wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1466\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[1;32m 1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m 1469\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 1470\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1471\u001b[0m )\n\u001b[1;32m 1472\u001b[0m ):\n\u001b[0;32m-> 1473\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/model_selection/_search.py:1019\u001b[0m, in \u001b[0;36mBaseSearchCV.fit\u001b[0;34m(self, X, y, **params)\u001b[0m\n\u001b[1;32m 1013\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_results(\n\u001b[1;32m 1014\u001b[0m all_candidate_params, n_splits, all_out, all_more_results\n\u001b[1;32m 1015\u001b[0m )\n\u001b[1;32m 1017\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m results\n\u001b[0;32m-> 1019\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_search\u001b[49m\u001b[43m(\u001b[49m\u001b[43mevaluate_candidates\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1021\u001b[0m \u001b[38;5;66;03m# multimetric is determined here because in the case of a callable\u001b[39;00m\n\u001b[1;32m 1022\u001b[0m \u001b[38;5;66;03m# self.scoring the return type is only known after calling\u001b[39;00m\n\u001b[1;32m 1023\u001b[0m first_test_score \u001b[38;5;241m=\u001b[39m all_out[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_scores\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", + "File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/model_selection/_search.py:1573\u001b[0m, in \u001b[0;36mGridSearchCV._run_search\u001b[0;34m(self, evaluate_candidates)\u001b[0m\n\u001b[1;32m 1571\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_run_search\u001b[39m(\u001b[38;5;28mself\u001b[39m, evaluate_candidates):\n\u001b[1;32m 1572\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Search all candidates in param_grid\"\"\"\u001b[39;00m\n\u001b[0;32m-> 1573\u001b[0m \u001b[43mevaluate_candidates\u001b[49m\u001b[43m(\u001b[49m\u001b[43mParameterGrid\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparam_grid\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/model_selection/_search.py:965\u001b[0m, in \u001b[0;36mBaseSearchCV.fit..evaluate_candidates\u001b[0;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[1;32m 957\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 958\u001b[0m \u001b[38;5;28mprint\u001b[39m(\n\u001b[1;32m 
959\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFitting \u001b[39m\u001b[38;5;132;01m{0}\u001b[39;00m\u001b[38;5;124m folds for each of \u001b[39m\u001b[38;5;132;01m{1}\u001b[39;00m\u001b[38;5;124m candidates,\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 960\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m totalling \u001b[39m\u001b[38;5;132;01m{2}\u001b[39;00m\u001b[38;5;124m fits\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 961\u001b[0m n_splits, n_candidates, n_candidates \u001b[38;5;241m*\u001b[39m n_splits\n\u001b[1;32m 962\u001b[0m )\n\u001b[1;32m 963\u001b[0m )\n\u001b[0;32m--> 965\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mparallel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 966\u001b[0m \u001b[43m \u001b[49m\u001b[43mdelayed\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_fit_and_score\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 967\u001b[0m \u001b[43m \u001b[49m\u001b[43mclone\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbase_estimator\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 968\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 969\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 970\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 971\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 972\u001b[0m \u001b[43m \u001b[49m\u001b[43mparameters\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparameters\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 973\u001b[0m \u001b[43m \u001b[49m\u001b[43msplit_progress\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43msplit_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_splits\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 974\u001b[0m \u001b[43m \u001b[49m\u001b[43mcandidate_progress\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcand_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_candidates\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 975\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfit_and_score_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 976\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 977\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mcand_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparameters\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43msplit_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mproduct\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 978\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcandidate_params\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 979\u001b[0m \u001b[43m 
\u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mrouted_params\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplitter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 980\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 981\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 984\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 985\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo fits were performed. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 986\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWas the CV iterator empty? \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 987\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWere there no candidates?\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 988\u001b[0m )\n", + "File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/sklearn/utils/parallel.py:74\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 69\u001b[0m config \u001b[38;5;241m=\u001b[39m get_config()\n\u001b[1;32m 70\u001b[0m iterable_with_config \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 71\u001b[0m (_with_config(delayed_func, config), args, kwargs)\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m delayed_func, args, kwargs \u001b[38;5;129;01min\u001b[39;00m iterable\n\u001b[1;32m 73\u001b[0m )\n\u001b[0;32m---> 74\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43miterable_with_config\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/joblib/parallel.py:2007\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 2001\u001b[0m \u001b[38;5;66;03m# The first item from the output is blank, but it makes the interpreter\u001b[39;00m\n\u001b[1;32m 2002\u001b[0m \u001b[38;5;66;03m# progress until it enters the Try/Except block of the generator and\u001b[39;00m\n\u001b[1;32m 2003\u001b[0m \u001b[38;5;66;03m# reaches the first `yield` statement. 
This starts the asynchronous\u001b[39;00m\n\u001b[1;32m 2004\u001b[0m \u001b[38;5;66;03m# dispatch of the tasks to the workers.\u001b[39;00m\n\u001b[1;32m 2005\u001b[0m \u001b[38;5;28mnext\u001b[39m(output)\n\u001b[0;32m-> 2007\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_generator \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/joblib/parallel.py:1650\u001b[0m, in \u001b[0;36mParallel._get_outputs\u001b[0;34m(self, iterator, pre_dispatch)\u001b[0m\n\u001b[1;32m 1647\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m\n\u001b[1;32m 1649\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backend\u001b[38;5;241m.\u001b[39mretrieval_context():\n\u001b[0;32m-> 1650\u001b[0m \u001b[38;5;28;01myield from\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_retrieve()\n\u001b[1;32m 1652\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mGeneratorExit\u001b[39;00m:\n\u001b[1;32m 1653\u001b[0m \u001b[38;5;66;03m# The generator has been garbage collected before being fully\u001b[39;00m\n\u001b[1;32m 1654\u001b[0m \u001b[38;5;66;03m# consumed. This aborts the remaining tasks if possible and warn\u001b[39;00m\n\u001b[1;32m 1655\u001b[0m \u001b[38;5;66;03m# the user if necessary.\u001b[39;00m\n\u001b[1;32m 1656\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n", + "File \u001b[0;32m~/aim_labs/lab_4/aimenv/lib/python3.12/site-packages/joblib/parallel.py:1762\u001b[0m, in \u001b[0;36mParallel._retrieve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1757\u001b[0m \u001b[38;5;66;03m# If the next job is not ready for retrieval yet, we just wait for\u001b[39;00m\n\u001b[1;32m 1758\u001b[0m \u001b[38;5;66;03m# async callbacks to progress.\u001b[39;00m\n\u001b[1;32m 1759\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ((\u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jobs) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m 1760\u001b[0m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jobs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mget_status(\n\u001b[1;32m 1761\u001b[0m timeout\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtimeout) \u001b[38;5;241m==\u001b[39m TASK_PENDING)):\n\u001b[0;32m-> 1762\u001b[0m \u001b[43mtime\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msleep\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0.01\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1763\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[1;32m 1765\u001b[0m \u001b[38;5;66;03m# We need to be careful: the job list can be filling up as\u001b[39;00m\n\u001b[1;32m 1766\u001b[0m \u001b[38;5;66;03m# we empty it and Python list are not thread-safe by\u001b[39;00m\n\u001b[1;32m 1767\u001b[0m \u001b[38;5;66;03m# default hence the use of the lock\u001b[39;00m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "from sklearn.model_selection import GridSearchCV\n", + "\n", + "\n", + "# Применение конвейера к данным\n", + "X_train_processing_result = pipeline_end.fit_transform(X_df_train)\n", + "X_test_processing_result = 
pipeline_end.transform(X_df_test)\n", + "\n", + "# Создание и настройка модели случайного леса\n", + "model = RandomForestRegressor()\n", + "\n", + "# Установка параметров для поиска по сетке\n", + "param_grid = {\n", + " 'n_estimators': [50, 100, 200], # Количество деревьев\n", + " 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n", + " 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n", + "}\n", + "\n", + "# Подбор гиперпараметров с помощью поиска по сетке\n", + "grid_search = GridSearchCV(estimator=model, \n", + " param_grid=param_grid,\n", + " scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n", + "\n", + "# Обучение модели на тренировочных данных\n", + "grid_search.fit(X_train_processing_result, y_train)\n", + "\n", + "# Результаты подбора гиперпараметров\n", + "print(\"Лучшие параметры:\", grid_search.best_params_)\n", + "# Меняем знак, так как берем отрицательное значение среднеквадратичной ошибки\n", + "print(\"Лучший результат (MSE):\", -grid_search.best_score_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Сравнение наборов гиперпараметров" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Установка параметров для поиска по сетке для старых значений\n", + "old_param_grid = {\n", + " 'n_estimators': [50, 100, 200], # Количество деревьев\n", + " 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n", + " 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n", + "}\n", + "\n", + "# Подбор гиперпараметров с помощью поиска по сетке для старых параметров\n", + "old_grid_search = GridSearchCV(estimator=model, \n", + " param_grid=old_param_grid,\n", + " scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n", + "\n", + "# Обучение модели на тренировочных данных\n", + "old_grid_search.fit(X_train_processing_result, y_train)\n", + "\n", + "# Результаты подбора для старых параметров\n", + "old_best_params = old_grid_search.best_params_\n", + "# Меняем знак, так как берем отрицательное значение MSE\n", + "old_best_mse = -old_grid_search.best_score_\n", + "\n", + "\n", + "# Установка параметров для поиска по сетке для новых значений\n", + "new_param_grid = {\n", + " 'n_estimators': [50],\n", + " 'max_depth': [5],\n", + " 'min_samples_split': [10]\n", + "}\n", + "\n", + "# Подбор гиперпараметров с помощью поиска по сетке для новых параметров\n", + "new_grid_search = GridSearchCV(estimator=model, \n", + " param_grid=new_param_grid,\n", + " scoring='neg_mean_squared_error', cv=2)\n", + "\n", + "# Обучение модели на тренировочных данных\n", + "new_grid_search.fit(X_train_processing_result, y_train)\n", + "\n", + "# Результаты подбора для новых параметров\n", + "new_best_params = new_grid_search.best_params_\n", + "# Меняем знак, так как берем отрицательное значение MSE\n", + "new_best_mse = -new_grid_search.best_score_\n", + "\n", + "\n", + "# Обучение модели с лучшими параметрами для новых значений\n", + "model_best = RandomForestRegressor(**new_best_params)\n", + "model_best.fit(X_train_processing_result, y_train)\n", + "\n", + "# Прогнозирование на тестовой выборке\n", + "y_pred = model_best.predict(X_test_processing_result)\n", + "\n", + "# Оценка производительности модели\n", + "mse = metrics.mean_squared_error(y_test, y_pred)\n", + "rmse = np.sqrt(mse)\n", + "\n", + "\n", + "# Вывод результатов\n", + "print(\"Старые параметры:\", old_best_params)\n", + "print(\"Лучший 
результат (MSE) на старых параметрах:\", old_best_mse)\n", + "print(\"\\nНовые параметры:\", new_best_params)\n", + "print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n", + "print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n", + "print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)\n", + "\n", + "# Обучение модели с лучшими параметрами для старых значений\n", + "model_old = RandomForestRegressor(**old_best_params)\n", + "model_old.fit(X_train_processing_result, y_train)\n", + "\n", + "# Прогнозирование на тестовой выборке для старых параметров\n", + "y_pred_old = model_old.predict(X_test_processing_result)\n", + "\n", + "# Визуализация ошибок\n", + "plt.figure(figsize=(10, 5))\n", + "plt.plot(y_test, label='Реальные значения', marker='o', linestyle='-', color='black')\n", + "plt.plot(y_pred_old, label='Предсказанные значения (старые параметры)', marker='x', linestyle='--', color='blue')\n", + "plt.plot(y_pred, label='Предсказанные значения (новые параметры)', marker='s', linestyle='--', color='orange')\n", + "plt.xlabel('Объекты')\n", + "plt.ylabel('Значения')\n", + "plt.title('Сравнение реальных и предсказанных значений')\n", + "plt.legend()\n", + "plt.show()" + ] } ], "metadata": { @@ -2177,7 +3997,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.8" } }, "nbformat": 4,
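
В выводе выше видно, что полный перебор по сетке (36 комбинаций, 3 фолда) был прерван вручную (KeyboardInterrupt) из-за большого времени счёта. Ниже — набросок (предположение, не код из лабораторной) того, как ограничить время подбора случайным поиском по той же сетке; предполагается, что `X_train_processing_result` и `y_train` получены так же, как в ячейке подбора гиперпараметров:

```python
# Набросок: случайный поиск по той же сетке вместо полного перебора
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import RandomizedSearchCV

param_grid = {
    "n_estimators": [50, 100, 200],   # количество деревьев
    "max_depth": [None, 10, 20, 30],  # максимальная глубина дерева
    "min_samples_split": [2, 5, 10],  # минимум образцов для разбиения узла
}

random_search = RandomizedSearchCV(
    estimator=RandomForestRegressor(),
    param_distributions=param_grid,
    n_iter=8,                          # проверяем лишь 8 случайных комбинаций из 36
    scoring="neg_mean_squared_error",
    cv=3,
    n_jobs=-1,
    random_state=9,
)
random_search.fit(X_train_processing_result, y_train)

print("Лучшие параметры:", random_search.best_params_)
# Знак меняем, так как scoring возвращает отрицательную MSE
print("Лучший результат (MSE):", -random_search.best_score_)
```

Число итераций `n_iter` здесь выбрано произвольно; при наличии времени его можно увеличить или вернуться к полному перебору.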