Files
AIM-PIbd-31-Anisin-R-S/lab_4/lab4.ipynb
2024-11-30 01:24:37 +04:00

4576 lines
529 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Лабораторная 4\n",
"Датасет: Набор данных для анализа и прогнозирования сердечного приступа"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['HeartDisease', 'BMI', 'Smoking', 'AlcoholDrinking', 'Stroke',\n",
" 'PhysicalHealth', 'MentalHealth', 'DiffWalking', 'Sex', 'AgeCategory',\n",
" 'Race', 'Diabetic', 'PhysicalActivity', 'GenHealth', 'SleepTime',\n",
" 'Asthma', 'KidneyDisease', 'SkinCancer'],\n",
" dtype='object')\n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\".//static//csv//heart_2020_cleaned.csv\")\n",
"print(df.columns)\n",
"map_heart_disease_to_int = {'No': 0, 'Yes': 1}\n",
"\n",
"df['Stroke'] = df['Stroke'].map(map_heart_disease_to_int).astype('int32')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Классификация"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Бизнес цель 1: \n",
"Предсказание сердечного приступа (Stroke) на основе других факторов."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем выборки"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>203716</th>\n",
" <td>No</td>\n",
" <td>30.99</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>70-74</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Fair</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>139550</th>\n",
" <td>No</td>\n",
" <td>32.61</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>5.0</td>\n",
" <td>10.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>5.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>314326</th>\n",
" <td>No</td>\n",
" <td>23.78</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>55-59</td>\n",
" <td>Other</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>79716</th>\n",
" <td>No</td>\n",
" <td>30.38</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>1.0</td>\n",
" <td>30.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>80 or older</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23944</th>\n",
" <td>Yes</td>\n",
" <td>24.96</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>75-79</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>270558</th>\n",
" <td>No</td>\n",
" <td>25.84</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Very good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>60811</th>\n",
" <td>No</td>\n",
" <td>29.84</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>3.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>30-34</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>263613</th>\n",
" <td>Yes</td>\n",
" <td>32.92</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>50-54</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>268192</th>\n",
" <td>No</td>\n",
" <td>37.42</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>60-64</td>\n",
" <td>Hispanic</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Poor</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50387</th>\n",
" <td>No</td>\n",
" <td>32.78</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>70-74</td>\n",
" <td>Black</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>15.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 18 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"203716 No 30.99 No No 0 0.0 \n",
"139550 No 32.61 No No 0 5.0 \n",
"314326 No 23.78 No No 0 0.0 \n",
"79716 No 30.38 No No 0 1.0 \n",
"23944 Yes 24.96 No No 0 0.0 \n",
"... ... ... ... ... ... ... \n",
"270558 No 25.84 No No 0 0.0 \n",
"60811 No 29.84 Yes No 0 0.0 \n",
"263613 Yes 32.92 Yes No 0 0.0 \n",
"268192 No 37.42 No No 0 30.0 \n",
"50387 No 32.78 No No 0 0.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n",
"203716 0.0 No Female 70-74 White No \n",
"139550 10.0 Yes Female 65-69 White Yes \n",
"314326 0.0 No Female 55-59 Other No \n",
"79716 30.0 Yes Female 80 or older White No \n",
"23944 0.0 Yes Female 75-79 White Yes \n",
"... ... ... ... ... ... ... \n",
"270558 0.0 No Male 65-69 White No \n",
"60811 3.0 No Male 30-34 White No \n",
"263613 0.0 Yes Female 50-54 White No \n",
"268192 0.0 Yes Female 60-64 Hispanic Yes \n",
"50387 0.0 No Female 70-74 Black Yes \n",
"\n",
" PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n",
"203716 Yes Fair 6.0 No No No \n",
"139550 No Good 5.0 Yes No No \n",
"314326 Yes Very good 7.0 No No No \n",
"79716 No Good 7.0 No No No \n",
"23944 Yes Good 8.0 No No No \n",
"... ... ... ... ... ... ... \n",
"270558 No Very good 8.0 No No No \n",
"60811 Yes Excellent 6.0 No No No \n",
"263613 Yes Good 8.0 No No No \n",
"268192 No Poor 6.0 No No No \n",
"50387 Yes Very good 15.0 No No No \n",
"\n",
"[255836 rows x 18 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Stroke</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>203716</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>139550</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>314326</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>79716</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23944</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>270558</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>60811</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>263613</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>268192</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50387</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Stroke\n",
"203716 0\n",
"139550 0\n",
"314326 0\n",
"79716 0\n",
"23944 0\n",
"... ...\n",
"270558 0\n",
"60811 0\n",
"263613 0\n",
"268192 0\n",
"50387 0\n",
"\n",
"[255836 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>86128</th>\n",
" <td>No</td>\n",
" <td>28.95</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>40-44</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29579</th>\n",
" <td>Yes</td>\n",
" <td>27.98</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>60-64</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>6.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9223</th>\n",
" <td>Yes</td>\n",
" <td>30.68</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>75-79</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>221689</th>\n",
" <td>No</td>\n",
" <td>23.73</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>42342</th>\n",
" <td>Yes</td>\n",
" <td>27.22</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>3.0</td>\n",
" <td>14.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>70-74</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>9.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23906</th>\n",
" <td>No</td>\n",
" <td>29.57</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>7.0</td>\n",
" <td>2.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>55-59</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75618</th>\n",
" <td>No</td>\n",
" <td>24.28</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>40-44</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Excellent</td>\n",
" <td>8.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>317847</th>\n",
" <td>No</td>\n",
" <td>27.96</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>60-64</td>\n",
" <td>Hispanic</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>169637</th>\n",
" <td>Yes</td>\n",
" <td>35.78</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>3.0</td>\n",
" <td>5.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>75-79</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>233255</th>\n",
" <td>No</td>\n",
" <td>32.69</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>50-54</td>\n",
" <td>Black</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Very good</td>\n",
" <td>7.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63959 rows × 18 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"86128 No 28.95 No No 0 0.0 \n",
"29579 Yes 27.98 No No 0 0.0 \n",
"9223 Yes 30.68 No No 0 0.0 \n",
"221689 No 23.73 No No 0 0.0 \n",
"42342 Yes 27.22 No No 0 3.0 \n",
"... ... ... ... ... ... ... \n",
"23906 No 29.57 Yes No 0 7.0 \n",
"75618 No 24.28 No No 0 0.0 \n",
"317847 No 27.96 No No 0 0.0 \n",
"169637 Yes 35.78 Yes No 0 3.0 \n",
"233255 No 32.69 Yes No 0 0.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n",
"86128 0.0 No Female 40-44 White No \n",
"29579 0.0 No Male 60-64 White No \n",
"9223 0.0 No Male 75-79 White Yes \n",
"221689 0.0 No Male 65-69 White No \n",
"42342 14.0 No Male 70-74 White No \n",
"... ... ... ... ... ... ... \n",
"23906 2.0 No Male 55-59 White No \n",
"75618 0.0 No Female 40-44 White No \n",
"317847 0.0 No Female 60-64 Hispanic No \n",
"169637 5.0 No Female 75-79 White No \n",
"233255 0.0 No Male 50-54 Black No \n",
"\n",
" PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n",
"86128 Yes Good 7.0 No No Yes \n",
"29579 Yes Good 6.0 Yes No No \n",
"9223 Yes Very good 8.0 No No No \n",
"221689 Yes Excellent 8.0 No No No \n",
"42342 Yes Excellent 9.0 No No No \n",
"... ... ... ... ... ... ... \n",
"23906 Yes Very good 7.0 No No No \n",
"75618 No Excellent 8.0 No No No \n",
"317847 No Good 6.0 No No No \n",
"169637 Yes Very good 7.0 Yes No No \n",
"233255 No Very good 7.0 No No No \n",
"\n",
"[63959 rows x 18 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Stroke</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>86128</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29579</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9223</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>221689</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>42342</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23906</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75618</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>317847</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>169637</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>233255</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63959 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Stroke\n",
"86128 0\n",
"29579 0\n",
"9223 0\n",
"221689 0\n",
"42342 0\n",
"... ...\n",
"23906 0\n",
"75618 0\n",
"317847 0\n",
"169637 0\n",
"233255 0\n",
"\n",
"[63959 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"Stroke\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=9\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Пропущенные значения по столбцам:\n",
"HeartDisease 0\n",
"BMI 0\n",
"Smoking 0\n",
"AlcoholDrinking 0\n",
"Stroke 0\n",
"PhysicalHealth 0\n",
"MentalHealth 0\n",
"DiffWalking 0\n",
"Sex 0\n",
"AgeCategory 0\n",
"Race 0\n",
"Diabetic 0\n",
"PhysicalActivity 0\n",
"GenHealth 0\n",
"SleepTime 0\n",
"Asthma 0\n",
"KidneyDisease 0\n",
"SkinCancer 0\n",
"dtype: int64\n",
"\n",
"Статистический обзор данных:\n",
" BMI Stroke PhysicalHealth MentalHealth \\\n",
"count 319795.000000 319795.000000 319795.00000 319795.000000 \n",
"mean 28.325399 0.037740 3.37171 3.898366 \n",
"std 6.356100 0.190567 7.95085 7.955235 \n",
"min 12.020000 0.000000 0.00000 0.000000 \n",
"25% 24.030000 0.000000 0.00000 0.000000 \n",
"50% 27.340000 0.000000 0.00000 0.000000 \n",
"75% 31.420000 0.000000 2.00000 3.000000 \n",
"max 94.850000 1.000000 30.00000 30.000000 \n",
"\n",
" SleepTime \n",
"count 319795.000000 \n",
"mean 7.097075 \n",
"std 1.436007 \n",
"min 1.000000 \n",
"25% 6.000000 \n",
"50% 7.000000 \n",
"75% 8.000000 \n",
"max 24.000000 \n"
]
}
],
"source": [
"null_values = df.isnull().sum()\n",
"print(\"Пропущенные значения по столбцам:\")\n",
"print(null_values)\n",
"\n",
"stat_summary = df.describe()\n",
"print(\"\\nСтатистический обзор данных:\")\n",
"print(stat_summary)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем конвеер для классификации данных и проверка конвеера"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>BMI</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>HeartDisease_Yes</th>\n",
" <th>Smoking_Yes</th>\n",
" <th>AlcoholDrinking_Yes</th>\n",
" <th>DiffWalking_Yes</th>\n",
" <th>Diabetic_No, borderline diabetes</th>\n",
" <th>Diabetic_Yes</th>\n",
" <th>Diabetic_Yes (during pregnancy)</th>\n",
" <th>PhysicalActivity_Yes</th>\n",
" <th>GenHealth_Fair</th>\n",
" <th>GenHealth_Good</th>\n",
" <th>GenHealth_Poor</th>\n",
" <th>GenHealth_Very good</th>\n",
" <th>Asthma_Yes</th>\n",
" <th>KidneyDisease_Yes</th>\n",
" <th>SkinCancer_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>203716</th>\n",
" <td>0.417528</td>\n",
" <td>-0.198038</td>\n",
" <td>-0.424855</td>\n",
" <td>-0.490470</td>\n",
" <td>-0.764158</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>139550</th>\n",
" <td>0.671963</td>\n",
" <td>-0.198038</td>\n",
" <td>0.202871</td>\n",
" <td>0.765292</td>\n",
" <td>-1.461699</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>314326</th>\n",
" <td>-0.714865</td>\n",
" <td>-0.198038</td>\n",
" <td>-0.424855</td>\n",
" <td>-0.490470</td>\n",
" <td>-0.066617</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>79716</th>\n",
" <td>0.321722</td>\n",
" <td>-0.198038</td>\n",
" <td>-0.299310</td>\n",
" <td>3.276817</td>\n",
" <td>-0.066617</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23944</th>\n",
" <td>-0.529536</td>\n",
" <td>-0.198038</td>\n",
" <td>-0.424855</td>\n",
" <td>-0.490470</td>\n",
" <td>0.630924</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>270558</th>\n",
" <td>-0.391324</td>\n",
" <td>-0.198038</td>\n",
" <td>-0.424855</td>\n",
" <td>-0.490470</td>\n",
" <td>0.630924</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>60811</th>\n",
" <td>0.236911</td>\n",
" <td>-0.198038</td>\n",
" <td>-0.424855</td>\n",
" <td>-0.113741</td>\n",
" <td>-0.764158</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>263613</th>\n",
" <td>0.720651</td>\n",
" <td>-0.198038</td>\n",
" <td>-0.424855</td>\n",
" <td>-0.490470</td>\n",
" <td>0.630924</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>268192</th>\n",
" <td>1.427415</td>\n",
" <td>-0.198038</td>\n",
" <td>3.341502</td>\n",
" <td>-0.490470</td>\n",
" <td>-0.764158</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50387</th>\n",
" <td>0.698663</td>\n",
" <td>-0.198038</td>\n",
" <td>-0.424855</td>\n",
" <td>-0.490470</td>\n",
" <td>5.513713</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 20 columns</p>\n",
"</div>"
],
"text/plain": [
" BMI Stroke PhysicalHealth MentalHealth SleepTime \\\n",
"203716 0.417528 -0.198038 -0.424855 -0.490470 -0.764158 \n",
"139550 0.671963 -0.198038 0.202871 0.765292 -1.461699 \n",
"314326 -0.714865 -0.198038 -0.424855 -0.490470 -0.066617 \n",
"79716 0.321722 -0.198038 -0.299310 3.276817 -0.066617 \n",
"23944 -0.529536 -0.198038 -0.424855 -0.490470 0.630924 \n",
"... ... ... ... ... ... \n",
"270558 -0.391324 -0.198038 -0.424855 -0.490470 0.630924 \n",
"60811 0.236911 -0.198038 -0.424855 -0.113741 -0.764158 \n",
"263613 0.720651 -0.198038 -0.424855 -0.490470 0.630924 \n",
"268192 1.427415 -0.198038 3.341502 -0.490470 -0.764158 \n",
"50387 0.698663 -0.198038 -0.424855 -0.490470 5.513713 \n",
"\n",
" HeartDisease_Yes Smoking_Yes AlcoholDrinking_Yes DiffWalking_Yes \\\n",
"203716 0.0 0.0 0.0 0.0 \n",
"139550 0.0 0.0 0.0 1.0 \n",
"314326 0.0 0.0 0.0 0.0 \n",
"79716 0.0 0.0 0.0 1.0 \n",
"23944 1.0 0.0 0.0 1.0 \n",
"... ... ... ... ... \n",
"270558 0.0 0.0 0.0 0.0 \n",
"60811 0.0 1.0 0.0 0.0 \n",
"263613 1.0 1.0 0.0 1.0 \n",
"268192 0.0 0.0 0.0 1.0 \n",
"50387 0.0 0.0 0.0 0.0 \n",
"\n",
" Diabetic_No, borderline diabetes Diabetic_Yes \\\n",
"203716 0.0 0.0 \n",
"139550 0.0 1.0 \n",
"314326 0.0 0.0 \n",
"79716 0.0 0.0 \n",
"23944 0.0 1.0 \n",
"... ... ... \n",
"270558 0.0 0.0 \n",
"60811 0.0 0.0 \n",
"263613 0.0 0.0 \n",
"268192 0.0 1.0 \n",
"50387 0.0 1.0 \n",
"\n",
" Diabetic_Yes (during pregnancy) PhysicalActivity_Yes GenHealth_Fair \\\n",
"203716 0.0 1.0 1.0 \n",
"139550 0.0 0.0 0.0 \n",
"314326 0.0 1.0 0.0 \n",
"79716 0.0 0.0 0.0 \n",
"23944 0.0 1.0 0.0 \n",
"... ... ... ... \n",
"270558 0.0 0.0 0.0 \n",
"60811 0.0 1.0 0.0 \n",
"263613 0.0 1.0 0.0 \n",
"268192 0.0 0.0 0.0 \n",
"50387 0.0 1.0 0.0 \n",
"\n",
" GenHealth_Good GenHealth_Poor GenHealth_Very good Asthma_Yes \\\n",
"203716 0.0 0.0 0.0 0.0 \n",
"139550 1.0 0.0 0.0 1.0 \n",
"314326 0.0 0.0 1.0 0.0 \n",
"79716 1.0 0.0 0.0 0.0 \n",
"23944 1.0 0.0 0.0 0.0 \n",
"... ... ... ... ... \n",
"270558 0.0 0.0 1.0 0.0 \n",
"60811 0.0 0.0 0.0 0.0 \n",
"263613 1.0 0.0 0.0 0.0 \n",
"268192 0.0 1.0 0.0 0.0 \n",
"50387 0.0 0.0 1.0 0.0 \n",
"\n",
" KidneyDisease_Yes SkinCancer_Yes \n",
"203716 0.0 0.0 \n",
"139550 0.0 0.0 \n",
"314326 0.0 0.0 \n",
"79716 0.0 0.0 \n",
"23944 0.0 0.0 \n",
"... ... ... \n",
"270558 0.0 0.0 \n",
"60811 0.0 0.0 \n",
"263613 0.0 0.0 \n",
"268192 0.0 0.0 \n",
"50387 0.0 0.0 \n",
"\n",
"[255836 rows x 20 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"columns_to_drop = ['AgeCategory', 'Sex', 'Race']\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")\n",
"\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем набор моделей"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=9\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=9,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучаем модели и тестируем их"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict, average=None)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict, average=None)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Матрица неточностей"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"No stroke\", \"Stroke\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Precision_train</th>\n",
" <th>Precision_test</th>\n",
" <th>Recall_train</th>\n",
" <th>Recall_test</th>\n",
" <th>Accuracy_train</th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_train</th>\n",
" <th>F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>logistic</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ridge</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>decision_tree</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>knn</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>naive_bayes</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gradient_boosting</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>random_forest</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mlp</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.999896</td>\n",
" <td>1.0</td>\n",
" <td>0.999996</td>\n",
" <td>1.0</td>\n",
" <td>[0.9999979689781726, 0.999948210678958]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Precision_train Precision_test Recall_train Recall_test \\\n",
"logistic 1.0 1.0 1.000000 1.0 \n",
"ridge 1.0 1.0 1.000000 1.0 \n",
"decision_tree 1.0 1.0 1.000000 1.0 \n",
"knn 1.0 1.0 1.000000 1.0 \n",
"naive_bayes 1.0 1.0 1.000000 1.0 \n",
"gradient_boosting 1.0 1.0 1.000000 1.0 \n",
"random_forest 1.0 1.0 1.000000 1.0 \n",
"mlp 1.0 1.0 0.999896 1.0 \n",
"\n",
" Accuracy_train Accuracy_test \\\n",
"logistic 1.000000 1.0 \n",
"ridge 1.000000 1.0 \n",
"decision_tree 1.000000 1.0 \n",
"knn 1.000000 1.0 \n",
"naive_bayes 1.000000 1.0 \n",
"gradient_boosting 1.000000 1.0 \n",
"random_forest 1.000000 1.0 \n",
"mlp 0.999996 1.0 \n",
"\n",
" F1_train F1_test \n",
"logistic [1.0, 1.0] [1.0, 1.0] \n",
"ridge [1.0, 1.0] [1.0, 1.0] \n",
"decision_tree [1.0, 1.0] [1.0, 1.0] \n",
"knn [1.0, 1.0] [1.0, 1.0] \n",
"naive_bayes [1.0, 1.0] [1.0, 1.0] \n",
"gradient_boosting [1.0, 1.0] [1.0, 1.0] \n",
"random_forest [1.0, 1.0] [1.0, 1.0] \n",
"mlp [0.9999979689781726, 0.999948210678958] [1.0, 1.0] "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_test</th>\n",
" <th>ROC_AUC_test</th>\n",
" <th>Cohen_kappa_test</th>\n",
" <th>MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>logistic</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>decision_tree</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>knn</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>naive_bayes</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>random_forest</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gradient_boosting</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mlp</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ridge</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Accuracy_test F1_test ROC_AUC_test Cohen_kappa_test \\\n",
"logistic 1.0 [1.0, 1.0] 1.0 1.0 \n",
"decision_tree 1.0 [1.0, 1.0] 1.0 1.0 \n",
"knn 1.0 [1.0, 1.0] 1.0 1.0 \n",
"naive_bayes 1.0 [1.0, 1.0] 1.0 1.0 \n",
"random_forest 1.0 [1.0, 1.0] 1.0 1.0 \n",
"gradient_boosting 1.0 [1.0, 1.0] 1.0 1.0 \n",
"mlp 1.0 [1.0, 1.0] 1.0 1.0 \n",
"ridge 1.0 [1.0, 1.0] 1.0 1.0 \n",
"\n",
" MCC_test \n",
"logistic 1.0 \n",
"decision_tree 1.0 \n",
"knn 1.0 \n",
"naive_bayes 1.0 \n",
"random_forest 1.0 \n",
"gradient_boosting 1.0 \n",
"mlp 1.0 \n",
"ridge 1.0 "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Лучшая модель"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Находим ошибки"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>Predicted</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [HeartDisease, Predicted, BMI, Smoking, AlcoholDrinking, Stroke, PhysicalHealth, MentalHealth, DiffWalking, Sex, AgeCategory, Race, Diabetic, PhysicalActivity, GenHealth, SleepTime, Asthma, KidneyDisease, SkinCancer]\n",
"Index: []"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_new_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"Stroke\"] != y_new_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_new_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>8556</th>\n",
" <td>No</td>\n",
" <td>38.41</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>1</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"8556 No 38.41 No No 1 0.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n",
"8556 0.0 No Female 65-69 White No \n",
"\n",
" PhysicalActivity GenHealth SleepTime Asthma KidneyDisease SkinCancer \n",
"8556 Yes Very good 6.0 No No No "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>BMI</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>SleepTime</th>\n",
" <th>HeartDisease_Yes</th>\n",
" <th>Smoking_Yes</th>\n",
" <th>AlcoholDrinking_Yes</th>\n",
" <th>DiffWalking_Yes</th>\n",
" <th>Diabetic_No, borderline diabetes</th>\n",
" <th>Diabetic_Yes</th>\n",
" <th>Diabetic_Yes (during pregnancy)</th>\n",
" <th>PhysicalActivity_Yes</th>\n",
" <th>GenHealth_Fair</th>\n",
" <th>GenHealth_Good</th>\n",
" <th>GenHealth_Poor</th>\n",
" <th>GenHealth_Very good</th>\n",
" <th>Asthma_Yes</th>\n",
" <th>KidneyDisease_Yes</th>\n",
" <th>SkinCancer_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>8556</th>\n",
" <td>1.582904</td>\n",
" <td>5.049532</td>\n",
" <td>-0.424855</td>\n",
" <td>-0.49047</td>\n",
" <td>-0.764158</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" BMI Stroke PhysicalHealth MentalHealth SleepTime \\\n",
"8556 1.582904 5.049532 -0.424855 -0.49047 -0.764158 \n",
"\n",
" HeartDisease_Yes Smoking_Yes AlcoholDrinking_Yes DiffWalking_Yes \\\n",
"8556 0.0 0.0 0.0 0.0 \n",
"\n",
" Diabetic_No, borderline diabetes Diabetic_Yes \\\n",
"8556 0.0 0.0 \n",
"\n",
" Diabetic_Yes (during pregnancy) PhysicalActivity_Yes GenHealth_Fair \\\n",
"8556 0.0 1.0 0.0 \n",
"\n",
" GenHealth_Good GenHealth_Poor GenHealth_Very good Asthma_Yes \\\n",
"8556 0.0 0.0 1.0 0.0 \n",
"\n",
" KidneyDisease_Yes SkinCancer_Yes \n",
"8556 0.0 0.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 1 (proba: [1.15647247e-04 9.99884353e-01])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 1'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 8556\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Создаем гиперпараметры методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\User\\Desktop\\aim\\aimvenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 5,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 50}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"\n",
"optimized_model_type = 'random_forest'\n",
"random_state = 9\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=42,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=50,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели и сама оценка данных"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Precision_train</th>\n",
" <th>Precision_test</th>\n",
" <th>Recall_train</th>\n",
" <th>Recall_test</th>\n",
" <th>Accuracy_train</th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_train</th>\n",
" <th>F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Name</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Old</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>New</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Precision_train Precision_test Recall_train Recall_test Accuracy_train \\\n",
"Name \n",
"Old 1.0 1.0 1.0 1.0 1.0 \n",
"New 1.0 1.0 1.0 1.0 1.0 \n",
"\n",
" Accuracy_test F1_train F1_test \n",
"Name \n",
"Old 1.0 [1.0, 1.0] [1.0, 1.0] \n",
"New 1.0 1.0 1.0 "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")\n",
"\n",
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_test</th>\n",
" <th>ROC_AUC_test</th>\n",
" <th>Cohen_kappa_test</th>\n",
" <th>MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Name</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Old</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>New</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Accuracy_test F1_test ROC_AUC_test Cohen_kappa_test MCC_test\n",
"Name \n",
"Old 1.0 [1.0, 1.0] 1.0 1.0 1.0\n",
"New 1.0 1.0 1.0 1.0 1.0"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"No stroke\", \"Stroke\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Модель идеально классифицировала объекты, которые относятся к \"No stroke\" и \"Stroke\"."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Регрессия"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Бизнес цель 2: \n",
"Предсказание среднего количества часов сна в день (SleepTime) на основе других факторов."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем выборки"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>95877</th>\n",
" <td>No</td>\n",
" <td>23.33</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>75-79</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>228939</th>\n",
" <td>Yes</td>\n",
" <td>27.46</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>55-59</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>260256</th>\n",
" <td>No</td>\n",
" <td>32.69</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>2.0</td>\n",
" <td>2.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>50-54</td>\n",
" <td>Hispanic</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Very good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84785</th>\n",
" <td>No</td>\n",
" <td>31.32</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>25-29</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>83845</th>\n",
" <td>Yes</td>\n",
" <td>24.63</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>2.0</td>\n",
" <td>10.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>80 or older</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>119879</th>\n",
" <td>No</td>\n",
" <td>29.65</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>60-64</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>259178</th>\n",
" <td>No</td>\n",
" <td>42.60</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>35-39</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131932</th>\n",
" <td>No</td>\n",
" <td>31.19</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>12.0</td>\n",
" <td>6.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>No, borderline diabetes</td>\n",
" <td>No</td>\n",
" <td>Very good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>146867</th>\n",
" <td>No</td>\n",
" <td>22.24</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>7.0</td>\n",
" <td>5.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>18-24</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>121958</th>\n",
" <td>No</td>\n",
" <td>36.39</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>30-34</td>\n",
" <td>Black</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 17 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"95877 No 23.33 Yes No No 0.0 \n",
"228939 Yes 27.46 Yes No Yes 30.0 \n",
"260256 No 32.69 No No No 2.0 \n",
"84785 No 31.32 No No No 0.0 \n",
"83845 Yes 24.63 Yes No No 2.0 \n",
"... ... ... ... ... ... ... \n",
"119879 No 29.65 No No No 0.0 \n",
"259178 No 42.60 Yes No No 0.0 \n",
"131932 No 31.19 Yes No No 12.0 \n",
"146867 No 22.24 No No No 7.0 \n",
"121958 No 36.39 Yes No No 0.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race \\\n",
"95877 0.0 No Male 75-79 White \n",
"228939 0.0 No Male 55-59 White \n",
"260256 2.0 No Male 50-54 Hispanic \n",
"84785 0.0 No Female 25-29 White \n",
"83845 10.0 No Male 80 or older White \n",
"... ... ... ... ... ... \n",
"119879 0.0 No Male 60-64 White \n",
"259178 5.0 No Male 35-39 White \n",
"131932 6.0 No Male 65-69 White \n",
"146867 5.0 No Female 18-24 White \n",
"121958 0.0 No Female 30-34 Black \n",
"\n",
" Diabetic PhysicalActivity GenHealth Asthma \\\n",
"95877 No Yes Very good No \n",
"228939 No Yes Good No \n",
"260256 No No Very good No \n",
"84785 No Yes Excellent No \n",
"83845 Yes Yes Good No \n",
"... ... ... ... ... \n",
"119879 No No Good No \n",
"259178 No Yes Good No \n",
"131932 No, borderline diabetes No Very good No \n",
"146867 No Yes Excellent No \n",
"121958 No Yes Good Yes \n",
"\n",
" KidneyDisease SkinCancer \n",
"95877 No No \n",
"228939 No No \n",
"260256 No No \n",
"84785 No No \n",
"83845 No No \n",
"... ... ... \n",
"119879 No No \n",
"259178 No No \n",
"131932 No No \n",
"146867 No No \n",
"121958 No No \n",
"\n",
"[255836 rows x 17 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>SleepTime</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>95877</th>\n",
" <td>7.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>228939</th>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>260256</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84785</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>83845</th>\n",
" <td>7.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>119879</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>259178</th>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131932</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>146867</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>121958</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" SleepTime\n",
"95877 7.0\n",
"228939 6.0\n",
"260256 8.0\n",
"84785 8.0\n",
"83845 7.0\n",
"... ...\n",
"119879 8.0\n",
"259178 6.0\n",
"131932 8.0\n",
"146867 8.0\n",
"121958 8.0\n",
"\n",
"[255836 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HeartDisease</th>\n",
" <th>BMI</th>\n",
" <th>Smoking</th>\n",
" <th>AlcoholDrinking</th>\n",
" <th>Stroke</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>DiffWalking</th>\n",
" <th>Sex</th>\n",
" <th>AgeCategory</th>\n",
" <th>Race</th>\n",
" <th>Diabetic</th>\n",
" <th>PhysicalActivity</th>\n",
" <th>GenHealth</th>\n",
" <th>Asthma</th>\n",
" <th>KidneyDisease</th>\n",
" <th>SkinCancer</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>271884</th>\n",
" <td>No</td>\n",
" <td>27.63</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>25.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>25-29</td>\n",
" <td>Hispanic</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>270361</th>\n",
" <td>No</td>\n",
" <td>21.95</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>20.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>30-34</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>219060</th>\n",
" <td>No</td>\n",
" <td>31.32</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>40-44</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24010</th>\n",
" <td>No</td>\n",
" <td>40.35</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>65-69</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>181930</th>\n",
" <td>No</td>\n",
" <td>35.61</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>30.0</td>\n",
" <td>30.0</td>\n",
" <td>Yes</td>\n",
" <td>Female</td>\n",
" <td>60-64</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Fair</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>181387</th>\n",
" <td>No</td>\n",
" <td>28.06</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>15.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>80 or older</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Very good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13791</th>\n",
" <td>No</td>\n",
" <td>29.68</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>7.0</td>\n",
" <td>25.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>35-39</td>\n",
" <td>Other</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Excellent</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>180164</th>\n",
" <td>No</td>\n",
" <td>21.11</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>4.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Female</td>\n",
" <td>35-39</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Good</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>94526</th>\n",
" <td>No</td>\n",
" <td>23.99</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>No</td>\n",
" <td>Male</td>\n",
" <td>70-74</td>\n",
" <td>White</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Excellent</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>107129</th>\n",
" <td>No</td>\n",
" <td>31.87</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>Yes</td>\n",
" <td>Male</td>\n",
" <td>60-64</td>\n",
" <td>White</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Poor</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63959 rows × 17 columns</p>\n",
"</div>"
],
"text/plain": [
" HeartDisease BMI Smoking AlcoholDrinking Stroke PhysicalHealth \\\n",
"271884 No 27.63 Yes No No 0.0 \n",
"270361 No 21.95 No No No 0.0 \n",
"219060 No 31.32 Yes No No 0.0 \n",
"24010 No 40.35 No No No 30.0 \n",
"181930 No 35.61 Yes No No 30.0 \n",
"... ... ... ... ... ... ... \n",
"181387 No 28.06 Yes No No 0.0 \n",
"13791 No 29.68 Yes No No 7.0 \n",
"180164 No 21.11 No No No 4.0 \n",
"94526 No 23.99 No No No 0.0 \n",
"107129 No 31.87 Yes No No 30.0 \n",
"\n",
" MentalHealth DiffWalking Sex AgeCategory Race Diabetic \\\n",
"271884 25.0 No Female 25-29 Hispanic No \n",
"270361 20.0 No Female 30-34 White No \n",
"219060 0.0 No Female 40-44 White No \n",
"24010 0.0 No Female 65-69 White No \n",
"181930 30.0 Yes Female 60-64 White No \n",
"... ... ... ... ... ... ... \n",
"181387 15.0 No Male 80 or older White No \n",
"13791 25.0 No Male 35-39 Other No \n",
"180164 0.0 No Female 35-39 White No \n",
"94526 0.0 No Male 70-74 White No \n",
"107129 0.0 Yes Male 60-64 White Yes \n",
"\n",
" PhysicalActivity GenHealth Asthma KidneyDisease SkinCancer \n",
"271884 Yes Very good No No No \n",
"270361 Yes Excellent No No Yes \n",
"219060 Yes Very good Yes No No \n",
"24010 No Good No No No \n",
"181930 No Fair Yes No Yes \n",
"... ... ... ... ... ... \n",
"181387 Yes Very good No No Yes \n",
"13791 No Excellent Yes No No \n",
"180164 Yes Good No No Yes \n",
"94526 Yes Excellent No No No \n",
"107129 No Poor No No No \n",
"\n",
"[63959 rows x 17 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>SleepTime</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>271884</th>\n",
" <td>7.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>270361</th>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>219060</th>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24010</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>181930</th>\n",
" <td>4.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>181387</th>\n",
" <td>7.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13791</th>\n",
" <td>3.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>180164</th>\n",
" <td>7.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>94526</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>107129</th>\n",
" <td>7.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>63959 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" SleepTime\n",
"271884 7.0\n",
"270361 6.0\n",
"219060 6.0\n",
"24010 8.0\n",
"181930 4.0\n",
"... ...\n",
"181387 7.0\n",
"13791 3.0\n",
"180164 7.0\n",
"94526 8.0\n",
"107129 7.0\n",
"\n",
"[63959 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df = pd.read_csv(\".//static//csv//heart_2020_cleaned.csv\")\n",
"\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" target_colname: str,\n",
" frac_train: float = 0.8,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" if target_colname not in df_input.columns:\n",
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
" \n",
" X = df_input.drop(columns=[target_colname])\n",
" y = df_input[[target_colname]]\n",
"\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
" df, \n",
" target_colname=\"SleepTime\", \n",
" frac_train=0.8, \n",
" random_state=42\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выполним one-hot encoding, чтобы избавиться от категориальных признаков"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>BMI</th>\n",
" <th>PhysicalHealth</th>\n",
" <th>MentalHealth</th>\n",
" <th>HeartDisease_Yes</th>\n",
" <th>Smoking_Yes</th>\n",
" <th>AlcoholDrinking_Yes</th>\n",
" <th>Stroke_Yes</th>\n",
" <th>DiffWalking_Yes</th>\n",
" <th>Sex_Male</th>\n",
" <th>AgeCategory_25-29</th>\n",
" <th>...</th>\n",
" <th>Diabetic_Yes</th>\n",
" <th>Diabetic_Yes (during pregnancy)</th>\n",
" <th>PhysicalActivity_Yes</th>\n",
" <th>GenHealth_Fair</th>\n",
" <th>GenHealth_Good</th>\n",
" <th>GenHealth_Poor</th>\n",
" <th>GenHealth_Very good</th>\n",
" <th>Asthma_Yes</th>\n",
" <th>KidneyDisease_Yes</th>\n",
" <th>SkinCancer_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>95877</th>\n",
" <td>23.33</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>228939</th>\n",
" <td>27.46</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>260256</th>\n",
" <td>32.69</td>\n",
" <td>2.0</td>\n",
" <td>2.0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84785</th>\n",
" <td>31.32</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>83845</th>\n",
" <td>24.63</td>\n",
" <td>2.0</td>\n",
" <td>10.0</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>119879</th>\n",
" <td>29.65</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>259178</th>\n",
" <td>42.60</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131932</th>\n",
" <td>31.19</td>\n",
" <td>12.0</td>\n",
" <td>6.0</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>146867</th>\n",
" <td>22.24</td>\n",
" <td>7.0</td>\n",
" <td>5.0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>121958</th>\n",
" <td>36.39</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>255836 rows × 37 columns</p>\n",
"</div>"
],
"text/plain": [
" BMI PhysicalHealth MentalHealth HeartDisease_Yes Smoking_Yes \\\n",
"95877 23.33 0.0 0.0 False True \n",
"228939 27.46 30.0 0.0 True True \n",
"260256 32.69 2.0 2.0 False False \n",
"84785 31.32 0.0 0.0 False False \n",
"83845 24.63 2.0 10.0 True True \n",
"... ... ... ... ... ... \n",
"119879 29.65 0.0 0.0 False False \n",
"259178 42.60 0.0 5.0 False True \n",
"131932 31.19 12.0 6.0 False True \n",
"146867 22.24 7.0 5.0 False False \n",
"121958 36.39 0.0 0.0 False True \n",
"\n",
" AlcoholDrinking_Yes Stroke_Yes DiffWalking_Yes Sex_Male \\\n",
"95877 False False False True \n",
"228939 False True False True \n",
"260256 False False False True \n",
"84785 False False False False \n",
"83845 False False False True \n",
"... ... ... ... ... \n",
"119879 False False False True \n",
"259178 False False False True \n",
"131932 False False False True \n",
"146867 False False False False \n",
"121958 False False False False \n",
"\n",
" AgeCategory_25-29 ... Diabetic_Yes Diabetic_Yes (during pregnancy) \\\n",
"95877 False ... False False \n",
"228939 False ... False False \n",
"260256 False ... False False \n",
"84785 True ... False False \n",
"83845 False ... True False \n",
"... ... ... ... ... \n",
"119879 False ... False False \n",
"259178 False ... False False \n",
"131932 False ... False False \n",
"146867 False ... False False \n",
"121958 False ... False False \n",
"\n",
" PhysicalActivity_Yes GenHealth_Fair GenHealth_Good GenHealth_Poor \\\n",
"95877 True False False False \n",
"228939 True False True False \n",
"260256 False False False False \n",
"84785 True False False False \n",
"83845 True False True False \n",
"... ... ... ... ... \n",
"119879 False False True False \n",
"259178 True False True False \n",
"131932 False False False False \n",
"146867 True False False False \n",
"121958 True False True False \n",
"\n",
" GenHealth_Very good Asthma_Yes KidneyDisease_Yes SkinCancer_Yes \n",
"95877 True False False False \n",
"228939 False False False False \n",
"260256 True False False False \n",
"84785 False False False False \n",
"83845 False False False False \n",
"... ... ... ... ... \n",
"119879 False False False False \n",
"259178 False False False False \n",
"131932 True False False False \n",
"146867 False False False False \n",
"121958 False True False False \n",
"\n",
"[255836 rows x 37 columns]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cat_features = ['HeartDisease', 'Smoking', 'AlcoholDrinking', 'Stroke',\n",
" 'DiffWalking', 'Sex', 'AgeCategory',\n",
" 'Race', 'Diabetic', 'PhysicalActivity', 'GenHealth',\n",
" 'Asthma', 'KidneyDisease', 'SkinCancer']\n",
"\n",
"X_test = pd.get_dummies(X_test, columns=cat_features, drop_first=True)\n",
"X_train = pd.get_dummies(X_train, columns=cat_features, drop_first=True)\n",
"\n",
"X_test\n",
"X_train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение перечня алгоритмов решения задачи регрессии"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import math\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
"\n",
" fitted_model = models[model_name][\"model\"].fit(\n",
" X_train.values, y_train.values.ravel()\n",
" )\n",
" y_train_pred = fitted_model.predict(X_train.values)\n",
" y_test_pred = fitted_model.predict(X_test.values)\n",
" models[model_name][\"fitted\"] = fitted_model\n",
" models[model_name][\"train_preds\"] = y_train_pred\n",
" models[model_name][\"preds\"] = y_test_pred\n",
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_train, y_train_pred)\n",
" )\n",
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выводим результаты оценки"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_d0fa9_row0_col0, #T_d0fa9_row1_col0 {\n",
" background-color: #90d743;\n",
" color: #000000;\n",
"}\n",
"#T_d0fa9_row0_col1, #T_d0fa9_row1_col1, #T_d0fa9_row7_col0 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row0_col2, #T_d0fa9_row1_col2 {\n",
" background-color: #5302a3;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row0_col3, #T_d0fa9_row1_col3, #T_d0fa9_row7_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row2_col0 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_d0fa9_row2_col1, #T_d0fa9_row3_col1, #T_d0fa9_row4_col1 {\n",
" background-color: #25848e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row2_col2 {\n",
" background-color: #5801a4;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row2_col3, #T_d0fa9_row3_col3, #T_d0fa9_row4_col3 {\n",
" background-color: #d6556d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row3_col0, #T_d0fa9_row4_col0 {\n",
" background-color: #a5db36;\n",
" color: #000000;\n",
"}\n",
"#T_d0fa9_row3_col2, #T_d0fa9_row4_col2, #T_d0fa9_row6_col2 {\n",
" background-color: #5601a4;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row5_col0 {\n",
" background-color: #9bd93c;\n",
" color: #000000;\n",
"}\n",
"#T_d0fa9_row5_col1 {\n",
" background-color: #24878e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row5_col2, #T_d0fa9_row7_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row5_col3 {\n",
" background-color: #d5536f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row6_col0, #T_d0fa9_row7_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_d0fa9_row6_col1 {\n",
" background-color: #228b8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d0fa9_row6_col3 {\n",
" background-color: #d14e72;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_d0fa9\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_d0fa9_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_d0fa9_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_d0fa9_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_d0fa9_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_d0fa9_level0_row0\" class=\"row_heading level0 row0\" >linear_poly</th>\n",
" <td id=\"T_d0fa9_row0_col0\" class=\"data row0 col0\" >1.397226</td>\n",
" <td id=\"T_d0fa9_row0_col1\" class=\"data row0 col1\" >1.413139</td>\n",
" <td id=\"T_d0fa9_row0_col2\" class=\"data row0 col2\" >0.999215</td>\n",
" <td id=\"T_d0fa9_row0_col3\" class=\"data row0 col3\" >0.042532</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d0fa9_level0_row1\" class=\"row_heading level0 row1\" >linear_interact</th>\n",
" <td id=\"T_d0fa9_row1_col0\" class=\"data row1 col0\" >1.397316</td>\n",
" <td id=\"T_d0fa9_row1_col1\" class=\"data row1 col1\" >1.413193</td>\n",
" <td id=\"T_d0fa9_row1_col2\" class=\"data row1 col2\" >0.999216</td>\n",
" <td id=\"T_d0fa9_row1_col3\" class=\"data row1 col3\" >0.042460</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d0fa9_level0_row2\" class=\"row_heading level0 row2\" >mlp</th>\n",
" <td id=\"T_d0fa9_row2_col0\" class=\"data row2 col0\" >1.404383</td>\n",
" <td id=\"T_d0fa9_row2_col1\" class=\"data row2 col1\" >1.416410</td>\n",
" <td id=\"T_d0fa9_row2_col2\" class=\"data row2 col2\" >1.000126</td>\n",
" <td id=\"T_d0fa9_row2_col3\" class=\"data row2 col3\" >0.038095</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d0fa9_level0_row3\" class=\"row_heading level0 row3\" >linear</th>\n",
" <td id=\"T_d0fa9_row3_col0\" class=\"data row3 col0\" >1.405231</td>\n",
" <td id=\"T_d0fa9_row3_col1\" class=\"data row3 col1\" >1.416610</td>\n",
" <td id=\"T_d0fa9_row3_col2\" class=\"data row3 col2\" >0.999855</td>\n",
" <td id=\"T_d0fa9_row3_col3\" class=\"data row3 col3\" >0.037823</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d0fa9_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
" <td id=\"T_d0fa9_row4_col0\" class=\"data row4 col0\" >1.405231</td>\n",
" <td id=\"T_d0fa9_row4_col1\" class=\"data row4 col1\" >1.416611</td>\n",
" <td id=\"T_d0fa9_row4_col2\" class=\"data row4 col2\" >0.999852</td>\n",
" <td id=\"T_d0fa9_row4_col3\" class=\"data row4 col3\" >0.037821</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d0fa9_level0_row5\" class=\"row_heading level0 row5\" >random_forest</th>\n",
" <td id=\"T_d0fa9_row5_col0\" class=\"data row5 col0\" >1.401999</td>\n",
" <td id=\"T_d0fa9_row5_col1\" class=\"data row5 col1\" >1.418929</td>\n",
" <td id=\"T_d0fa9_row5_col2\" class=\"data row5 col2\" >0.998045</td>\n",
" <td id=\"T_d0fa9_row5_col3\" class=\"data row5 col3\" >0.034671</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d0fa9_level0_row6\" class=\"row_heading level0 row6\" >decision_tree</th>\n",
" <td id=\"T_d0fa9_row6_col0\" class=\"data row6 col0\" >1.406670</td>\n",
" <td id=\"T_d0fa9_row6_col1\" class=\"data row6 col1\" >1.422338</td>\n",
" <td id=\"T_d0fa9_row6_col2\" class=\"data row6 col2\" >0.999876</td>\n",
" <td id=\"T_d0fa9_row6_col3\" class=\"data row6 col3\" >0.030026</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d0fa9_level0_row7\" class=\"row_heading level0 row7\" >knn</th>\n",
" <td id=\"T_d0fa9_row7_col0\" class=\"data row7 col0\" >1.296527</td>\n",
" <td id=\"T_d0fa9_row7_col1\" class=\"data row7 col1\" >1.507555</td>\n",
" <td id=\"T_d0fa9_row7_col2\" class=\"data row7 col2\" >1.039156</td>\n",
" <td id=\"T_d0fa9_row7_col3\" class=\"data row7 col3\" >-0.089685</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1d24abd35f0>"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выводим лучшую модель"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'linear_poly'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбираем гиперпараметры методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\User\\Desktop\\aim\\aimvenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 50}\n",
"Лучший результат (MSE): 1.9790374490880065\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestRegressor\n",
"\n",
"\n",
"X = df[['HeartDisease', 'BMI', 'Smoking', 'AlcoholDrinking', 'Stroke',\n",
" 'PhysicalHealth', 'MentalHealth', 'DiffWalking',\n",
" 'Diabetic', 'PhysicalActivity', 'GenHealth',\n",
" 'Asthma', 'KidneyDisease', 'SkinCancer']]\n",
"y = df['SleepTime'] \n",
"\n",
"model = RandomForestRegressor() \n",
"\n",
"param_grid = {\n",
" 'n_estimators': [50, 100], \n",
" 'max_depth': [10, 20], \n",
" 'min_samples_split': [5, 10] \n",
"}\n",
"\n",
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"grid_search.fit(X_train, y_train)\n",
"\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучаем модель с новыми гиперпараметрами и сравниваем новых данных со старыми"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\User\\Desktop\\aim\\aimvenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"c:\\Users\\User\\Desktop\\aim\\aimvenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"c:\\Users\\User\\Desktop\\aim\\aimvenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"c:\\Users\\User\\Desktop\\aim\\aimvenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"c:\\Users\\User\\Desktop\\aim\\aimvenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"c:\\Users\\User\\Desktop\\aim\\aimvenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Старые параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 100}\n",
"Лучший результат (MSE) на старых параметрах: 1.9789879323889759\n",
"\n",
"Новые параметры: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 100}\n",
"Лучший результат (MSE) на новых параметрах: 1.9835849471109568\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 2.005535804883726\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 1.416169412494044\n"
]
}
],
"source": [
"# Old data\n",
"\n",
"old_param_grid = param_grid\n",
"old_grid_search = grid_search\n",
"old_grid_search.fit(X_train, y_train)\n",
"\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_ \n",
"\n",
"# New data\n",
"\n",
"new_param_grid = {\n",
" 'n_estimators': [100],\n",
" 'max_depth': [10],\n",
" 'min_samples_split': [5]\n",
" }\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"new_grid_search.fit(X_train, y_train)\n",
"\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_\n",
"\n",
"new_best_model = RandomForestRegressor(**new_best_params)\n",
"new_best_model.fit(X_train, y_train)\n",
"\n",
"old_best_model = RandomForestRegressor(**old_best_params)\n",
"old_best_model.fit(X_train, y_train)\n",
"\n",
"y_new_pred = new_best_model.predict(X_test)\n",
"y_old_pred = old_best_model.predict(X_test)\n",
"\n",
"mse = metrics.mean_squared_error(y_test, y_new_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nНовые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Визуализация данных"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(16, 8))\n",
"plt.scatter(range(len(y_test)), y_test, label=\"Истинные значения\", color=\"black\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_new_pred, label=\"Предсказанные (новые параметры)\", color=\"blue\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_old_pred, label=\"Предсказанные (старые параметры)\", color=\"red\", alpha=0.5)\n",
"plt.xlabel(\"Выборка\")\n",
"plt.ylabel(\"Значения\")\n",
"plt.legend()\n",
"plt.title(\"Сравнение предсказанных и истинных значений\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimvenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}