{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Stroke data\n",
"\n",
"Load the dataset and take a first look at its columns:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 83,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>gender</th>\n",
|
|||
|
" <th>age</th>\n",
|
|||
|
" <th>hypertension</th>\n",
|
|||
|
" <th>heart_disease</th>\n",
|
|||
|
" <th>ever_married</th>\n",
|
|||
|
" <th>work_type</th>\n",
|
|||
|
" <th>Residence_type</th>\n",
|
|||
|
" <th>avg_glucose_level</th>\n",
|
|||
|
" <th>bmi</th>\n",
|
|||
|
" <th>smoking_status</th>\n",
|
|||
|
" <th>stroke</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>id</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>9046</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>67.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>228.69</td>\n",
|
|||
|
" <td>36.6</td>\n",
|
|||
|
" <td>formerly smoked</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>51676</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>61.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>202.21</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>31112</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>80.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>105.92</td>\n",
|
|||
|
" <td>32.5</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>60182</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>49.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>171.23</td>\n",
|
|||
|
" <td>34.4</td>\n",
|
|||
|
" <td>smokes</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1665</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>79.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>174.12</td>\n",
|
|||
|
" <td>24.0</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>18234</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>80.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>83.75</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>44873</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>81.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>125.20</td>\n",
|
|||
|
" <td>40.0</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>19723</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>35.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>82.99</td>\n",
|
|||
|
" <td>30.6</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>37544</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>51.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>166.29</td>\n",
|
|||
|
" <td>25.6</td>\n",
|
|||
|
" <td>formerly smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>44679</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>44.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Govt_job</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>85.28</td>\n",
|
|||
|
" <td>26.2</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>5110 rows × 11 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" gender age hypertension heart_disease ever_married work_type \\\n",
|
|||
|
"id \n",
|
|||
|
"9046 Male 67.0 0 1 Yes Private \n",
|
|||
|
"51676 Female 61.0 0 0 Yes Self-employed \n",
|
|||
|
"31112 Male 80.0 0 1 Yes Private \n",
|
|||
|
"60182 Female 49.0 0 0 Yes Private \n",
|
|||
|
"1665 Female 79.0 1 0 Yes Self-employed \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"18234 Female 80.0 1 0 Yes Private \n",
|
|||
|
"44873 Female 81.0 0 0 Yes Self-employed \n",
|
|||
|
"19723 Female 35.0 0 0 Yes Self-employed \n",
|
|||
|
"37544 Male 51.0 0 0 Yes Private \n",
|
|||
|
"44679 Female 44.0 0 0 Yes Govt_job \n",
|
|||
|
"\n",
|
|||
|
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
|
|||
|
"id \n",
|
|||
|
"9046 Urban 228.69 36.6 formerly smoked 1 \n",
|
|||
|
"51676 Rural 202.21 NaN never smoked 1 \n",
|
|||
|
"31112 Rural 105.92 32.5 never smoked 1 \n",
|
|||
|
"60182 Urban 171.23 34.4 smokes 1 \n",
|
|||
|
"1665 Rural 174.12 24.0 never smoked 1 \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"18234 Urban 83.75 NaN never smoked 0 \n",
|
|||
|
"44873 Urban 125.20 40.0 never smoked 0 \n",
|
|||
|
"19723 Rural 82.99 30.6 never smoked 0 \n",
|
|||
|
"37544 Rural 166.29 25.6 formerly smoked 0 \n",
|
|||
|
"44679 Urban 85.28 26.2 Unknown 0 \n",
|
|||
|
"\n",
|
|||
|
"[5110 rows x 11 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 83,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"\n",
|
|||
|
"random_state = 9\n",
"\n",
"df = pd.read_csv(\"../../static/csv/healthcare-dataset-stroke-data.csv\", index_col=\"id\")\n",
|
|||
|
"\n",
|
|||
|
"df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Business goals\n",
"\n",
"### Classification\n",
"\n",
"Goal: build a classification model that predicts whether a person is likely to have a stroke, based on socio-demographic factors, health status and lifestyle.\n",
"\n",
"Applications:\n",
"\n",
"1. Healthcare facilities: the model can be used to identify patients at high risk of stroke early, so that preventive measures can be taken and the likelihood of serious consequences reduced.\n",
"2. Clinical decision support systems: the model can be embedded in electronic health records to automatically alert physicians to patients in the elevated-risk group.\n",
"3. Educational programs: the model can help raise public awareness of stroke risk factors and of ways to reduce them, which may also improve disease prevention.\n",
"\n",
"### Regression\n",
"\n",
"Goal: build a regression model that predicts a person's average blood glucose level (the `avg_glucose_level` column) from socio-demographic factors, health status and lifestyle. The model should reveal whether a person's glucose level tends to be higher or lower and, later on, help assess the risks associated with the patient's condition.\n",
"\n",
"Applications:\n",
"\n",
"1. Healthcare facilities: early identification of patients with a potentially high glucose level, so that they can be monitored and given preventive measures that reduce the risk of diabetes and other complications.\n",
"2. Clinical decision support systems: integrating the model into medical records would give physicians an estimate of a patient's glucose level, which simplifies monitoring and patient management, especially when real-time laboratory data are unavailable.\n",
"3. Educational programs and public health: the model can help raise public awareness of the factors that affect glucose levels and suggest lifestyle changes that help keep glucose within the normal range.\n",
"\n",
"## Achievable model quality\n",
"\n",
"A stroke classification model trained on this dataset can reach good quality, but with some limitations.\n",
"\n",
"- Informative features: the dataset contains important stroke risk factors (for example, age, hypertension and heart disease). These features most likely give the model enough information to recognize the elevated-risk group.\n",
"\n",
"- Data limitations: although the key medical factors are present, the dataset contains no genetic data, no detailed medical history, and no detailed information on diet or physical activity, all of which also affect stroke risk. This may cap the quality the model can reach.\n",
"\n",
"For the regression task of predicting blood glucose level, the model can capture general trends, but its accuracy will be limited.\n",
"\n",
"- Feature informativeness: the data contain features related to glucose level (for example, age, smoking and hypertension), which can be used to produce a rough, population-level forecast.\n",
"\n",
"- Missing factors: glucose level depends strongly on diet, physical activity and hormonal changes, none of which are present in the data. As a result, the model will estimate this parameter only with limited accuracy."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Classification\n",
"\n",
"Split the dataset into training and test sets (80/20), stratified by the target feature `stroke`."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 84,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from typing import Tuple\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input,\n",
"    stratify_colname=\"y\",\n",
"    frac_train=0.6,\n",
"    frac_val=0.15,\n",
"    frac_test=0.25,\n",
"    random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"\n",
"    Splits a Pandas dataframe into three subsets (train, val, and test)\n",
"    following fractional ratios provided by the user, where each subset is\n",
"    stratified by the values in a specific column (that is, each subset has\n",
"    the same relative frequency of the values in the column). It performs this\n",
"    splitting by running train_test_split() twice.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df_input : Pandas dataframe\n",
"        Input dataframe to be split.\n",
"    stratify_colname : str\n",
"        The name of the column that will be used for stratification. Usually\n",
"        this column would be for the label.\n",
"    frac_train : float\n",
"    frac_val : float\n",
"    frac_test : float\n",
"        The ratios with which the dataframe will be split into train, val, and\n",
"        test data. The values should be expressed as float fractions and should\n",
"        sum to 1.0. If frac_val is 0, only the train and test splits are made.\n",
"    random_state : int, None, or RandomStateInstance\n",
"        Value to be passed to train_test_split().\n",
"\n",
"    Returns\n",
"    -------\n",
"    df_train, df_val, df_test, y_train, y_val, y_test : DataFrame\n",
"        The three splits of the input dataframe (all columns, including the\n",
"        stratification column), followed by the matching splits of the\n",
"        stratification column alone.\n",
"    \"\"\"\n",
"\n",
"    # Compare with a tolerance: float fractions such as 0.7 + 0.1 + 0.2\n",
"    # do not sum to exactly 1.0.\n",
"    if abs(frac_train + frac_val + frac_test - 1.0) > 1e-9:\n",
"        raise ValueError(\n",
"            \"fractions %f, %f, %f do not add up to 1.0\"\n",
"            % (frac_train, frac_val, frac_test)\n",
"        )\n",
"\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"\n",
"    X = df_input  # Contains all columns, including the target column.\n",
"    y = df_input[\n",
"        [stratify_colname]\n",
"    ]  # Dataframe of just the column on which to stratify.\n",
"\n",
"    # Split original dataframe into train and temp dataframes.\n",
"    df_train, df_temp, y_train, y_temp = train_test_split(\n",
"        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
"    )\n",
"\n",
"    if frac_val <= 0:\n",
"        # No validation split: the temp part is returned as the test set.\n",
"        assert len(df_input) == len(df_train) + len(df_temp)\n",
"        return df_train, DataFrame(), df_temp, y_train, DataFrame(), y_temp\n",
"\n",
"    # Split the temp dataframe into val and test dataframes.\n",
"    relative_frac_test = frac_test / (frac_val + frac_test)\n",
"    df_val, df_test, y_val, y_test = train_test_split(\n",
"        df_temp,\n",
"        y_temp,\n",
"        stratify=y_temp,\n",
"        test_size=relative_frac_test,\n",
"        random_state=random_state,\n",
"    )\n",
"\n",
"    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"    return df_train, df_val, df_test, y_train, y_val, y_test"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 85,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>gender</th>\n",
|
|||
|
" <th>age</th>\n",
|
|||
|
" <th>hypertension</th>\n",
|
|||
|
" <th>heart_disease</th>\n",
|
|||
|
" <th>ever_married</th>\n",
|
|||
|
" <th>work_type</th>\n",
|
|||
|
" <th>Residence_type</th>\n",
|
|||
|
" <th>avg_glucose_level</th>\n",
|
|||
|
" <th>bmi</th>\n",
|
|||
|
" <th>smoking_status</th>\n",
|
|||
|
" <th>stroke</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>id</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>22159</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>54.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>97.06</td>\n",
|
|||
|
" <td>28.5</td>\n",
|
|||
|
" <td>formerly smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>8920</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>51.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>76.35</td>\n",
|
|||
|
" <td>33.5</td>\n",
|
|||
|
" <td>formerly smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>65507</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>33.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>55.72</td>\n",
|
|||
|
" <td>38.2</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>43196</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>52.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>59.54</td>\n",
|
|||
|
" <td>42.2</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>59745</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>27.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>76.74</td>\n",
|
|||
|
" <td>53.9</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>66546</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>20.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>80.08</td>\n",
|
|||
|
" <td>25.1</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>68798</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>58.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>59.86</td>\n",
|
|||
|
" <td>28.0</td>\n",
|
|||
|
" <td>formerly smoked</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>61409</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>32.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Govt_job</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>58.24</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>formerly smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>69259</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>77.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>100.85</td>\n",
|
|||
|
" <td>29.5</td>\n",
|
|||
|
" <td>smokes</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>17231</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>24.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>90.42</td>\n",
|
|||
|
" <td>24.3</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>4088 rows × 11 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" gender age hypertension heart_disease ever_married work_type \\\n",
|
|||
|
"id \n",
|
|||
|
"22159 Female 54.0 1 0 No Private \n",
|
|||
|
"8920 Female 51.0 0 0 Yes Self-employed \n",
|
|||
|
"65507 Male 33.0 0 0 Yes Private \n",
|
|||
|
"43196 Female 52.0 0 0 Yes Self-employed \n",
|
|||
|
"59745 Female 27.0 0 0 Yes Private \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"66546 Female 20.0 0 0 No Private \n",
|
|||
|
"68798 Female 58.0 0 0 Yes Private \n",
|
|||
|
"61409 Male 32.0 1 0 No Govt_job \n",
|
|||
|
"69259 Female 77.0 0 0 Yes Private \n",
|
|||
|
"17231 Female 24.0 0 0 No Private \n",
|
|||
|
"\n",
|
|||
|
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
|
|||
|
"id \n",
|
|||
|
"22159 Urban 97.06 28.5 formerly smoked 0 \n",
|
|||
|
"8920 Rural 76.35 33.5 formerly smoked 0 \n",
|
|||
|
"65507 Rural 55.72 38.2 never smoked 0 \n",
|
|||
|
"43196 Urban 59.54 42.2 Unknown 0 \n",
|
|||
|
"59745 Urban 76.74 53.9 Unknown 0 \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"66546 Urban 80.08 25.1 never smoked 0 \n",
|
|||
|
"68798 Rural 59.86 28.0 formerly smoked 1 \n",
|
|||
|
"61409 Urban 58.24 NaN formerly smoked 0 \n",
|
|||
|
"69259 Rural 100.85 29.5 smokes 0 \n",
|
|||
|
"17231 Urban 90.42 24.3 never smoked 0 \n",
|
|||
|
"\n",
|
|||
|
"[4088 rows x 11 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>stroke</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>id</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>22159</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>8920</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>65507</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>43196</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>59745</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>66546</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>68798</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>61409</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>69259</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>17231</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>4088 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" stroke\n",
|
|||
|
"id \n",
|
|||
|
"22159 0\n",
|
|||
|
"8920 0\n",
|
|||
|
"65507 0\n",
|
|||
|
"43196 0\n",
|
|||
|
"59745 0\n",
|
|||
|
"... ...\n",
|
|||
|
"66546 0\n",
|
|||
|
"68798 1\n",
|
|||
|
"61409 0\n",
|
|||
|
"69259 0\n",
|
|||
|
"17231 0\n",
|
|||
|
"\n",
|
|||
|
"[4088 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>gender</th>\n",
|
|||
|
" <th>age</th>\n",
|
|||
|
" <th>hypertension</th>\n",
|
|||
|
" <th>heart_disease</th>\n",
|
|||
|
" <th>ever_married</th>\n",
|
|||
|
" <th>work_type</th>\n",
|
|||
|
" <th>Residence_type</th>\n",
|
|||
|
" <th>avg_glucose_level</th>\n",
|
|||
|
" <th>bmi</th>\n",
|
|||
|
" <th>smoking_status</th>\n",
|
|||
|
" <th>stroke</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>id</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>18072</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>39.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Govt_job</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>107.47</td>\n",
|
|||
|
" <td>21.3</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>67063</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>62.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>130.56</td>\n",
|
|||
|
" <td>36.1</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>40387</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>17.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>77.46</td>\n",
|
|||
|
" <td>24.0</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>18032</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>62.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>90.61</td>\n",
|
|||
|
" <td>25.8</td>\n",
|
|||
|
" <td>smokes</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>5478</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>60.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>203.04</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>smokes</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>57710</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>50.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>112.25</td>\n",
|
|||
|
" <td>21.6</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>63043</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>27.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>61.80</td>\n",
|
|||
|
" <td>26.8</td>\n",
|
|||
|
" <td>formerly smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>63986</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>60.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>153.48</td>\n",
|
|||
|
" <td>37.3</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>28461</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>15.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Never_worked</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>79.59</td>\n",
|
|||
|
" <td>28.4</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>54975</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>64.06</td>\n",
|
|||
|
" <td>18.9</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>1022 rows × 11 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" gender age hypertension heart_disease ever_married work_type \\\n",
|
|||
|
"id \n",
|
|||
|
"18072 Female 39.0 0 0 Yes Govt_job \n",
|
|||
|
"67063 Male 62.0 0 0 Yes Self-employed \n",
|
|||
|
"40387 Female 17.0 0 0 No Private \n",
|
|||
|
"18032 Male 62.0 0 1 Yes Private \n",
|
|||
|
"5478 Female 60.0 0 0 Yes Self-employed \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"57710 Female 50.0 0 0 Yes Private \n",
|
|||
|
"63043 Female 27.0 0 0 No Private \n",
|
|||
|
"63986 Male 60.0 0 0 Yes Private \n",
|
|||
|
"28461 Male 15.0 0 0 No Never_worked \n",
|
|||
|
"54975 Male 7.0 0 0 No Self-employed \n",
|
|||
|
"\n",
|
|||
|
" Residence_type avg_glucose_level bmi smoking_status stroke \n",
|
|||
|
"id \n",
|
|||
|
"18072 Urban 107.47 21.3 Unknown 0 \n",
|
|||
|
"67063 Urban 130.56 36.1 Unknown 0 \n",
|
|||
|
"40387 Rural 77.46 24.0 Unknown 0 \n",
|
|||
|
"18032 Rural 90.61 25.8 smokes 0 \n",
|
|||
|
"5478 Urban 203.04 NaN smokes 0 \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"57710 Rural 112.25 21.6 Unknown 0 \n",
|
|||
|
"63043 Urban 61.80 26.8 formerly smoked 0 \n",
|
|||
|
"63986 Rural 153.48 37.3 never smoked 0 \n",
|
|||
|
"28461 Rural 79.59 28.4 Unknown 0 \n",
|
|||
|
"54975 Rural 64.06 18.9 Unknown 0 \n",
|
|||
|
"\n",
|
|||
|
"[1022 rows x 11 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>stroke</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>id</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>18072</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>67063</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>40387</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>18032</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>5478</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>57710</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>63043</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>63986</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>28461</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>54975</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>1022 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" stroke\n",
|
|||
|
"id \n",
|
|||
|
"18072 0\n",
|
|||
|
"67063 0\n",
|
|||
|
"40387 0\n",
|
|||
|
"18032 0\n",
|
|||
|
"5478 0\n",
|
|||
|
"... ...\n",
|
|||
|
"57710 0\n",
|
|||
|
"63043 0\n",
|
|||
|
"63986 0\n",
|
|||
|
"28461 0\n",
|
|||
|
"54975 0\n",
|
|||
|
"\n",
|
|||
|
"[1022 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
|
|||
|
" df, stratify_colname=\"stroke\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_train\", X_train)\n",
|
|||
|
"display(\"y_train\", y_train)\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_test\", X_test)\n",
|
|||
|
"display(\"y_test\", y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Let us choose a baseline for the classification task. For this we use random prediction: for every example, a randomly chosen class is returned as the prediction."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 86,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Baseline Accuracy: 0.5205479452054794\n",
|
|||
|
"Baseline Precision: 0.05823293172690763\n",
|
|||
|
"Baseline Recall: 0.58\n",
|
|||
|
"Baseline F1 Score: 0.10583941605839416\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score\n",
|
|||
|
"\n",
|
|||
|
"# Unique classes of the target variable, taken from the training set\n",
"unique_classes = np.unique(y_train)\n",
"\n",
"# Random predictions: each test example gets a class drawn uniformly from the target's values\n",
"random_predictions = np.random.choice(unique_classes, size=len(y_test))\n",
"\n",
"# Baseline metrics\n",
|
|||
|
"baseline_accuracy = accuracy_score(y_test, random_predictions)\n",
|
|||
|
"baseline_precision = precision_score(y_test, random_predictions)\n",
|
|||
|
"baseline_recall = recall_score(y_test, random_predictions)\n",
|
|||
|
"baseline_f1 = f1_score(y_test, random_predictions)\n",
|
|||
|
"\n",
|
|||
|
"print('Baseline Accuracy:', baseline_accuracy)\n",
|
|||
|
"print('Baseline Precision:', baseline_precision)\n",
|
|||
|
"print('Baseline Recall:', baseline_recall)\n",
|
|||
|
"print('Baseline F1 Score:', baseline_f1)"
|
|||
|
]
|
|||
|
},
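{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same kind of baseline can also be built with scikit-learn's `DummyClassifier(strategy='uniform')`. The cell below is only a sketch on hypothetical labels. It uses 1000 examples with 50 positives, about 5%, which is chosen to resemble the share of strokes and is not taken from the dataset. A fixed seed of 9 makes the result reproducible. The sketch also shows why the baseline precision above is so low. Under uniform random guessing, recall is expected to be close to 0.5, while precision is expected to be close to the share of positive examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.dummy import DummyClassifier\n",
"from sklearn.metrics import precision_score, recall_score\n",
"\n",
"# Hypothetical labels: 1000 examples, 50 of them positive (about 5%)\n",
"rng = np.random.default_rng(9)\n",
"y_demo = np.zeros(1000, dtype=int)\n",
"y_demo[:50] = 1\n",
"rng.shuffle(y_demo)\n",
"X_demo = np.zeros((len(y_demo), 1))  # DummyClassifier ignores the features\n",
"\n",
"# strategy='uniform' predicts every class with equal probability; random_state fixes the draw\n",
"dummy = DummyClassifier(strategy='uniform', random_state=9)\n",
"dummy.fit(X_demo, y_demo)\n",
"demo_preds = dummy.predict(X_demo)\n",
"\n",
"print('Share of positives:', y_demo.mean())\n",
"print('Precision:', precision_score(y_demo, demo_preds))  # expected to be near the share of positives\n",
"print('Recall:', recall_score(y_demo, demo_preds))  # expected to be near 0.5"
]
},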
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"The following metrics were used:\n",
"\n",
"- Accuracy: the share of correct predictions among all examples. It is intuitive but of little practical use when the classes are imbalanced, because it does not show how well the model predicts the rare class. On this test set, which has roughly 5% positives, always predicting 'no stroke' would already give an accuracy of about 0.95.\n",
"- Precision: the share of true positives among all examples predicted as positive, `TP / (TP + FP)`. Precision shows how selective the model is when it predicts the positive class. It matters when false alarms are undesirable (for example, wrongly predicting a stroke).\n",
"- Recall: the share of positive-class examples that the algorithm found among all positive-class examples, `TP / (TP + FN)`. Recall shows the model's ability to detect all positive examples. In this task a high recall is important because it minimizes missed stroke cases.\n",
"- F1 Score: the harmonic mean of precision and recall, `2 * Precision * Recall / (Precision + Recall)`, which balances the two. It is useful when both precision and recall matter. It is especially useful with imbalanced classes, where detecting all cases (recall) has to be traded off against minimizing false alarms (precision).\n",
"\n",
"Together these metrics cover different aspects of model behaviour, from the ability to recognize the rare class to the overall share of correct answers. This lets us assess the model from several angles."
|
|||
|
]
|
|||
|
},
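{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a small worked check of these definitions using hypothetical counts, not data from this dataset. There are 100 patients, 10 of whom had a stroke, and the model finds 8 of them while raising 12 false alarms. The metrics are computed by hand from TP, FP, FN and TN and compared with scikit-learn. The last lines illustrate the point about accuracy: a model that never predicts a stroke reaches 0.9 accuracy here, while its recall is 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n",
"\n",
"# Hypothetical counts: 10 actual strokes, 8 detected, 12 false alarms, 100 patients in total\n",
"tp, fn, fp, tn = 8, 2, 12, 78\n",
"y_true_demo = np.array([1] * tp + [1] * fn + [0] * fp + [0] * tn)\n",
"y_pred_demo = np.array([1] * tp + [0] * fn + [1] * fp + [0] * tn)\n",
"\n",
"precision_demo = tp / (tp + fp)  # 8 / 20 = 0.4\n",
"recall_demo = tp / (tp + fn)  # 8 / 10 = 0.8\n",
"accuracy_demo = (tp + tn) / (tp + fn + fp + tn)  # 86 / 100 = 0.86\n",
"f1_demo = 2 * precision_demo * recall_demo / (precision_demo + recall_demo)  # about 0.533\n",
"\n",
"# The hand-computed values agree with scikit-learn\n",
"assert np.isclose(precision_demo, precision_score(y_true_demo, y_pred_demo))\n",
"assert np.isclose(recall_demo, recall_score(y_true_demo, y_pred_demo))\n",
"assert np.isclose(accuracy_demo, accuracy_score(y_true_demo, y_pred_demo))\n",
"assert np.isclose(f1_demo, f1_score(y_true_demo, y_pred_demo))\n",
"\n",
"# A model that never predicts a stroke: high accuracy, zero recall\n",
"always_negative = np.zeros_like(y_true_demo)\n",
"print('Accuracy:', accuracy_score(y_true_demo, always_negative))  # 0.9\n",
"print('Recall:', recall_score(y_true_demo, always_negative))  # 0.0"
]
},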
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Building the classification pipeline"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 87,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder\n",
|
|||
|
"\n",
|
|||
|
"columns_to_drop = [\"work_type\", \"stroke\"]\n",
|
|||
|
"columns_not_to_modify = [\"hypertension\", \"heart_disease\"]\n",
|
|||
|
"\n",
|
|||
|
"num_columns = [\n",
|
|||
|
" column\n",
|
|||
|
" for column in df.columns\n",
|
|||
|
" if column not in columns_to_drop\n",
|
|||
|
" and column not in columns_not_to_modify\n",
|
|||
|
" and df[column].dtype != \"object\"\n",
|
|||
|
"]\n",
|
|||
|
"\n",
|
|||
|
"cat_columns = [\n",
|
|||
|
" column\n",
|
|||
|
" for column in df.columns\n",
|
|||
|
" if column not in columns_to_drop\n",
|
|||
|
" and column not in columns_not_to_modify\n",
|
|||
|
" and df[column].dtype == \"object\"\n",
|
|||
|
"]\n",
|
|||
|
"\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" ]\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
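{
"cell_type": "markdown",
"metadata": {},
"source": [
"The second `ColumnTransformer` in the pipeline above selects `work_type` and `stroke` by name. This works only if the first step returns a DataFrame rather than a NumPy array. The output below has named columns and an `id` index, so pandas output appears to be enabled elsewhere in the notebook. That is an assumption, since the setting is not shown in this section. Below is a minimal sketch of the same two-step pattern on a hypothetical toy frame, with pandas output enabled explicitly through `set_output(transform='pandas')`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"# Hypothetical toy frame: a numeric column, a categorical column and a column to drop\n",
"toy_df = pd.DataFrame({\n",
"    'age': [30.0, np.nan, 50.0, 70.0],\n",
"    'gender': ['Male', 'Female', 'Female', 'Male'],\n",
"    'note': ['a', 'b', 'c', 'd'],\n",
"})\n",
"\n",
"toy_prep = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        ('num', Pipeline([('imputer', SimpleImputer(strategy='median')), ('scaler', StandardScaler())]), ['age']),\n",
"        ('cat', OneHotEncoder(sparse_output=False, drop='first'), ['gender']),\n",
"    ],\n",
"    remainder='passthrough',\n",
")\n",
"# Selecting 'note' by name in this second step requires a DataFrame from the first step\n",
"toy_drop = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[('drop', 'drop', ['note'])],\n",
"    remainder='passthrough',\n",
")\n",
"\n",
"# set_output(transform='pandas') makes every step return a DataFrame with column names\n",
"toy_pipeline = Pipeline([('prep', toy_prep), ('drop', toy_drop)]).set_output(transform='pandas')\n",
"toy_result = toy_pipeline.fit_transform(toy_df)\n",
"print(toy_result.columns.tolist())  # ['age', 'gender_Male']"
]
},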
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Now let us check how the pipeline works:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 88,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>age</th>\n",
|
|||
|
" <th>avg_glucose_level</th>\n",
|
|||
|
" <th>bmi</th>\n",
|
|||
|
" <th>gender_Male</th>\n",
|
|||
|
" <th>gender_Other</th>\n",
|
|||
|
" <th>ever_married_Yes</th>\n",
|
|||
|
" <th>Residence_type_Urban</th>\n",
|
|||
|
" <th>smoking_status_formerly smoked</th>\n",
|
|||
|
" <th>smoking_status_never smoked</th>\n",
|
|||
|
" <th>smoking_status_smokes</th>\n",
|
|||
|
" <th>hypertension</th>\n",
|
|||
|
" <th>heart_disease</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>id</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>22159</th>\n",
|
|||
|
" <td>0.472344</td>\n",
|
|||
|
" <td>-0.194427</td>\n",
|
|||
|
" <td>-0.059214</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>8920</th>\n",
|
|||
|
" <td>0.339807</td>\n",
|
|||
|
" <td>-0.653763</td>\n",
|
|||
|
" <td>0.587887</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>65507</th>\n",
|
|||
|
" <td>-0.455418</td>\n",
|
|||
|
" <td>-1.111325</td>\n",
|
|||
|
" <td>1.196162</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>43196</th>\n",
|
|||
|
" <td>0.383986</td>\n",
|
|||
|
" <td>-1.026600</td>\n",
|
|||
|
" <td>1.713843</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>59745</th>\n",
|
|||
|
" <td>-0.720492</td>\n",
|
|||
|
" <td>-0.645113</td>\n",
|
|||
|
" <td>3.228060</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>66546</th>\n",
|
|||
|
" <td>-1.029746</td>\n",
|
|||
|
" <td>-0.571034</td>\n",
|
|||
|
" <td>-0.499243</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>68798</th>\n",
|
|||
|
" <td>0.649060</td>\n",
|
|||
|
" <td>-1.019502</td>\n",
|
|||
|
" <td>-0.123924</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>61409</th>\n",
|
|||
|
" <td>-0.499597</td>\n",
|
|||
|
" <td>-1.055433</td>\n",
|
|||
|
" <td>-0.098040</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>69259</th>\n",
|
|||
|
" <td>1.488464</td>\n",
|
|||
|
" <td>-0.110367</td>\n",
|
|||
|
" <td>0.070206</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>17231</th>\n",
|
|||
|
" <td>-0.853030</td>\n",
|
|||
|
" <td>-0.341699</td>\n",
|
|||
|
" <td>-0.602779</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>4088 rows × 12 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" age avg_glucose_level bmi gender_Male gender_Other \\\n",
|
|||
|
"id \n",
|
|||
|
"22159 0.472344 -0.194427 -0.059214 0.0 0.0 \n",
|
|||
|
"8920 0.339807 -0.653763 0.587887 0.0 0.0 \n",
|
|||
|
"65507 -0.455418 -1.111325 1.196162 1.0 0.0 \n",
|
|||
|
"43196 0.383986 -1.026600 1.713843 0.0 0.0 \n",
|
|||
|
"59745 -0.720492 -0.645113 3.228060 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"66546 -1.029746 -0.571034 -0.499243 0.0 0.0 \n",
|
|||
|
"68798 0.649060 -1.019502 -0.123924 0.0 0.0 \n",
|
|||
|
"61409 -0.499597 -1.055433 -0.098040 1.0 0.0 \n",
|
|||
|
"69259 1.488464 -0.110367 0.070206 0.0 0.0 \n",
|
|||
|
"17231 -0.853030 -0.341699 -0.602779 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" ever_married_Yes Residence_type_Urban smoking_status_formerly smoked \\\n",
|
|||
|
"id \n",
|
|||
|
"22159 0.0 1.0 1.0 \n",
|
|||
|
"8920 1.0 0.0 1.0 \n",
|
|||
|
"65507 1.0 0.0 0.0 \n",
|
|||
|
"43196 1.0 1.0 0.0 \n",
|
|||
|
"59745 1.0 1.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"66546 0.0 1.0 0.0 \n",
|
|||
|
"68798 1.0 0.0 1.0 \n",
|
|||
|
"61409 0.0 1.0 1.0 \n",
|
|||
|
"69259 1.0 0.0 0.0 \n",
|
|||
|
"17231 0.0 1.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" smoking_status_never smoked smoking_status_smokes hypertension \\\n",
|
|||
|
"id \n",
|
|||
|
"22159 0.0 0.0 1 \n",
|
|||
|
"8920 0.0 0.0 0 \n",
|
|||
|
"65507 1.0 0.0 0 \n",
|
|||
|
"43196 0.0 0.0 0 \n",
|
|||
|
"59745 0.0 0.0 0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"66546 1.0 0.0 0 \n",
|
|||
|
"68798 0.0 0.0 0 \n",
|
|||
|
"61409 0.0 0.0 1 \n",
|
|||
|
"69259 0.0 1.0 0 \n",
|
|||
|
"17231 1.0 0.0 0 \n",
|
|||
|
"\n",
|
|||
|
" heart_disease \n",
|
|||
|
"id \n",
|
|||
|
"22159 0 \n",
|
|||
|
"8920 0 \n",
|
|||
|
"65507 0 \n",
|
|||
|
"43196 0 \n",
|
|||
|
"59745 0 \n",
|
|||
|
"... ... \n",
|
|||
|
"66546 0 \n",
|
|||
|
"68798 0 \n",
|
|||
|
"61409 0 \n",
|
|||
|
"69259 0 \n",
|
|||
|
"17231 0 \n",
|
|||
|
"\n",
|
|||
|
"[4088 rows x 12 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 88,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"preprocessed_df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Let us find the optimal hyperparameters for each of the chosen models with grid search and collect the tuned models:\n",
"\n",
"- `knn`: k-nearest neighbors\n",
"- `random_forest`: random forest (an ensemble of decision trees)\n",
"- `mlp`: multilayer perceptron (a neural network)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Лучшие параметры для knn: {'n_neighbors': 1, 'weights': 'uniform'}\n",
|
|||
|
"Лучшие параметры для random_forest: {'class_weight': 'balanced_subsample', 'criterion': 'entropy', 'max_depth': 7, 'max_features': 'sqrt', 'n_estimators': 50, 'random_state': 9}\n",
|
|||
|
"Лучшие параметры для mlp: {'alpha': np.float64(0.1), 'early_stopping': True, 'hidden_layer_sizes': np.int64(14), 'max_iter': 1000, 'random_state': 9, 'solver': 'adam'}\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import GridSearchCV\n",
|
|||
|
"from sklearn import neighbors, ensemble, neural_network\n",
|
|||
|
"\n",
|
|||
|
"# Candidate hyperparameter values for each model\n",
"param_grids = {\n",
"    \"knn\": {\n",
"        \"n_neighbors\": list(range(1, 31)),\n",
"        \"weights\": [\"uniform\", \"distance\"]\n",
"    },\n",
"    \"random_forest\": {\n",
"        \"n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"        \"max_features\": [\"sqrt\", \"log2\", 2],\n",
"        \"max_depth\": [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
"        \"criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"        \"random_state\": [random_state],\n",
"        \"class_weight\": [\"balanced\", \"balanced_subsample\"]\n",
"    },\n",
"    \"mlp\": {\n",
"        \"solver\": [\"adam\"],\n",
"        \"max_iter\": list(range(1000, 2001, 100)),\n",
"        \"alpha\": 10.0 ** -np.arange(1, 10),\n",
"        \"hidden_layer_sizes\": np.arange(10, 15),\n",
"        \"early_stopping\": [True, False],\n",
"        \"random_state\": [random_state]\n",
"    }\n",
"}\n",
|
|||
|
"\n",
|
|||
|
"# Model instances\n",
"models = {\n",
"    \"knn\": neighbors.KNeighborsClassifier(),\n",
"    \"random_forest\": ensemble.RandomForestClassifier(),\n",
"    \"mlp\": neural_network.MLPClassifier()\n",
"}\n",
"\n",
"# Dictionary that stores each model configured with its best parameters\n",
"class_models = {}\n",
"\n",
"# Run the grid search for every model\n",
"for model_name, model in models.items():\n",
"    # GridSearchCV for the current model, optimizing F1 for the positive (stroke) class\n",
"    gs_optimizer = GridSearchCV(estimator=model, param_grid=param_grids[model_name], scoring=\"f1\", n_jobs=-1)\n",
"\n",
"    # Fit GridSearchCV on the preprocessed training data\n",
"    gs_optimizer.fit(preprocessed_df, y_train.values.ravel())\n",
"\n",
"    # Best parameters found\n",
"    best_params = gs_optimizer.best_params_\n",
"    print(f\"Лучшие параметры для {model_name}: {best_params}\")\n",
"\n",
"    class_models[model_name] = {\n",
"        \"model\": model.set_params(**best_params)  # configure the model with the best parameters\n",
"    }"
|
|||
|
]
|
|||
|
},
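{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the cell above, the scaler and the encoder are fitted on the whole training set (`preprocessed_df`) before cross-validation. The validation folds inside `GridSearchCV` have therefore already influenced the preprocessing. A common way to avoid this is to tune the whole pipeline and address the model's parameters with the step-name prefix. In this notebook that would mean combining `pipeline_end` with the model, as the training cell below does, and prefixing the grid keys with `model__`. The sketch below is only an illustration of the technique and not a rerun of the notebook's search. It uses synthetic data from `make_classification` with about 5% positives and a deliberately small grid for `KNeighborsClassifier`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.model_selection import GridSearchCV, StratifiedKFold\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Synthetic imbalanced data (roughly 5% positives) standing in for the stroke data\n",
"X_gs, y_gs = make_classification(n_samples=600, n_features=6, weights=[0.95], random_state=9)\n",
"\n",
"demo_pipeline = Pipeline([('scaler', StandardScaler()), ('model', KNeighborsClassifier())])\n",
"\n",
"# Parameters of a pipeline step are addressed as '<step name>__<parameter>'\n",
"demo_grid = {\n",
"    'model__n_neighbors': list(range(1, 11)),\n",
"    'model__weights': ['uniform', 'distance'],\n",
"}\n",
"\n",
"# The scaler is refitted on the training part of every fold, so validation folds do not leak into it\n",
"demo_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=9)\n",
"demo_search = GridSearchCV(demo_pipeline, demo_grid, scoring='f1', cv=demo_cv)\n",
"demo_search.fit(X_gs, y_gs)\n",
"\n",
"print(demo_search.best_params_)\n",
"print(round(demo_search.best_score_, 3))"
]
},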
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Next, we train the models and evaluate their quality."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 90,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: knn\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import confusion_matrix\n",
|
|||
|
"\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
"\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
|||
|
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
|||
|
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"Precision_train\"] = precision_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Precision_test\"] = precision_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_train\"] = recall_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_test\"] = recall_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_train\"] = accuracy_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_test\"] = accuracy_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" ) \n",
|
|||
|
" class_models[model_name][\"F1_train\"] = f1_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"F1_test\"] = f1_score(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"Confusion_matrix\"] = confusion_matrix(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )"
|
|||
|
]
|
|||
|
},
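{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above turns probabilities into classes with a fixed threshold of 0.5. When the positive class is rare, this can leave recall very low. One option is to choose the threshold that maximizes F1 on a separate validation split, using `precision_recall_curve`. This is a sketch on synthetic data, with `LogisticRegression` as a stand-in model. The notebook itself uses `frac_val=0`, so applying this here would require an additional validation split. Choosing the threshold on the test set instead would make the test metrics optimistic."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import f1_score, precision_recall_curve\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Synthetic imbalanced data (roughly 5% positives) and a stratified validation split\n",
"X_thr, y_thr = make_classification(n_samples=2000, n_features=8, weights=[0.95], random_state=9)\n",
"X_thr_train, X_thr_val, y_thr_train, y_thr_val = train_test_split(\n",
"    X_thr, y_thr, test_size=0.3, stratify=y_thr, random_state=9\n",
")\n",
"\n",
"clf = LogisticRegression(max_iter=1000).fit(X_thr_train, y_thr_train)\n",
"val_probs = clf.predict_proba(X_thr_val)[:, 1]\n",
"\n",
"# precision and recall have one more element than thresholds, so the last point is dropped\n",
"pr_precision, pr_recall, pr_thresholds = precision_recall_curve(y_thr_val, val_probs)\n",
"pr_f1 = 2 * pr_precision[:-1] * pr_recall[:-1] / np.clip(pr_precision[:-1] + pr_recall[:-1], 1e-12, None)\n",
"best_threshold = pr_thresholds[np.argmax(pr_f1)]\n",
"\n",
"f1_default = f1_score(y_thr_val, (val_probs > 0.5).astype(int))\n",
"f1_tuned = f1_score(y_thr_val, (val_probs >= best_threshold).astype(int))\n",
"print('Best threshold:', round(float(best_threshold), 3))\n",
"print('F1 at 0.5:', round(f1_default, 3), '| F1 at best threshold:', round(f1_tuned, 3))"
]
},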
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Confusion matrices:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 91,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABBIAAANrCAYAAAD70rtBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC8qklEQVR4nOzdfXzN9f/H8efZ9Ww7G2IzZsjlihQ1QynGSKX4VWrlIvFNVBSVvrloyqILIhddyMU36qsLvkIKhTCSUkKuQ9gUZi7a1Tmf3x9y6pwNZ2ezq8/jfrt9brfO+/N+f87r7Hu+28vr836/PxbDMAwBAAAAAAC4waukAwAAAAAAAGUHhQQAAAAAAOA2CgkAAAAAAMBtFBIAAAAAAIDbKCQAAAAAAAC3UUgAAAAAAABuo5AAAAAAAADcRiEBAAAAAAC4jUICAAAAAABwG4UEAJKkUaNGyWKx6I8//ijpUAAAQDGyWCwaNWpUSYdxQf/5z3/UsGFD+fr6KiwsrKTDASAKCQAAAABKqV9++UW9evXSlVdeqXfeeUdvv/12SYeUx+HDhzVq1Cht3ry5pEMBio1PSQcAAAAAAPlZuXKl7Ha73njjDdWtW7ekw8nX4cOH9cILL6hWrVpq2rRpSYcDFAtmJAAAAAClyJkzZ0o6hFLj6NGjklSkSxrOnj1bZNcCzIpCAoAL2r9/v+rWraurr75aaWlpuvnmm3X11Vdr27ZtuuWWW1ShQgVVr15d48aNcxq3cuVKWSwWzZs3Ty+99JJq1KihgIAAtWvXTrt37y6hTwMAQOlzfo+ibdu26f7771fFihXVunVr/fTTT+rVq5fq1KmjgIAARURE6KGHHtKxY8fyHb9792716tVLYWFhCg0NVe/evfP8gzkrK0uDBw9WlSpVFBISojvuuEO//fZbvnH98MMP6tSpk6xWq4KDg9WuXTutX7/eqc/MmTNlsVi0Zs0aPf7446pSpYrCwsL0r3/9S9nZ2UpPT1ePHj1UsWJFVaxYUU8//bQMw3D7Z1OrVi2NHDlSklSlSpU8ezlMmTJFV111lfz9/RUZGakBAwYoPT3d6Rrnc5dNmzbppptuUoUKFfTcc885fh4jR45U3bp15e/vr6ioKD399NPKyspyusayZcvUunVrhYWFKTg4WA0aNHBcY+XKlbr++uslSb1795bFYpHFYtHMmTPd/pxAWcTSBgD52rNnj9q2batKlSpp2bJluuKKKyRJJ06cUMeOHdW1a1fdc889+vjjj/XMM8+ocePG6tSpk9M1Xn75ZXl5eWnIkCE6efKkxo0bp8TERG3YsKEkPhIAAKXW3XffrXr16mnMmDEyDEPLli3T3r171bt3b0VERGjr1q16++23tXXrVq1fv14Wi8Vp/D333KPatWsrOTlZ33//vd59911VrVpVY8eOdfR5+OGH9f777+v+++9Xy5Yt9dVXX6lz5855Ytm6datuvPFGWa1WPf300/L19dVbb72lm2++WatWrVJsbKxT/8cee0wRERF64YUXtH79er399tsKCwvTunXrVLNmTY0ZM0ZLlizRK6+8oquvvlo9evRw62cyYcIEzZ49W/Pnz9fUqVMVHBysJk2aSDpXQHnhhRcUHx+v/v37a8eOHZo6dao2btyotWvXytfX13GdY8eOqVOnTurevbseeOABhYeHy26364477tCaNWvUr18/NWrUSFu2bNH48eO1c+dOLViwwPGzuO2229SkSRMlJSXJ399fu3fv1tq1ayVJjRo1UlJSkkaMGKF+/frpxhtvlCS1bNnSrc8IlFkGABiGMXLkSEOS8fvvvxvbt283IiMjjeuvv944fvy4o0+bNm0MScbs2bMdbVlZWUZERITRrVs3R9vXX39tSDIaNWpkZGVlOdrfeOMNQ5KxZcuW4vlQAACUcuf//t53331O7WfPns3T94MPPjAkGatXr84z/qGHHnLqe9dddxmVK1d2vN68ebMhyXj00U
ed+t1///2GJGPkyJGOtjvvvNPw8/Mz9uzZ42g7fPiwERISYtx0002OthkzZhiSjISEBMNutzva4+LiDIvFYjzyyCOOttzcXKNGjRpGmzZtLvETcfbP/OS8o0ePGn5+fkaHDh0Mm83maH/zzTcNScZ7773naDufu0ybNs3puv/5z38MLy8v45tvvnFqnzZtmiHJWLt2rWEYhjF+/Pg87+9q48aNhiRjxowZBfpsQFnG0gYATn7++We1adNGtWrV0vLly1WxYkWn88HBwXrggQccr/38/HTDDTdo7969ea7Vu3dv+fn5OV6fr9Ln1xcAADN75JFHnF4HBgY6/jszM1N//PGHWrRoIUn6/vvvLzn+xhtv1LFjx5SRkSFJWrJkiSTp8ccfd+o3aNAgp9c2m01ffvml7rzzTtWpU8fRXq1aNd1///1as2aN45rn9enTx2mGRGxsrAzDUJ8+fRxt3t7eat68eZHkAMuXL1d2drYGDRokL6+//znTt29fWa1WLV682Km/v7+/evfu7dT20UcfqVGjRmrYsKH++OMPx9G2bVtJ0tdffy3p770Z/ve//8lutxc6dqC8oJAAwMntt9+ukJAQffHFF7JarXnO16hRI890yooVK+rEiRN5+tasWTNPP0n59gUAwMxq167t9Pr48eN64oknFB4ersDAQFWpUsXR5+TJk3nGX+pv7v79++Xl5aUrr7zSqV+DBg2cXv/+++86e/Zsnnbp3DR+u92ugwcPXvS9Q0NDJUlRUVF52osiB9i/f3++sfv5+alOnTqO8+dVr17d6caGJO3atUtbt25VlSpVnI769etL+nuTx3vvvVetWrXSww8/rPDwcHXv3l3z5s2jqADTY48EAE66deumWbNmac6cOfrXv/6V57y3t3e+44x8Nk8qSF8AAMzsnzMQpHN7Hqxbt05Dhw5V06ZNFRwcLLvdro4dO+b7j9iS/Jt7offOr70kcgDXn60k2e12NW7cWK+//nq+Y84XQQIDA7V69Wp9/fXXWrx4sZYuXar//ve/atu2rb788ssLfnagvKOQAMDJK6+8Ih8fHz366KMKCQnR/fffX9IhAQBgKidOnNCKFSv0wgsvaMSIEY72Xbt2eXzN6Oho2e127dmzx+lO/o4dO5z6ValSRRUqVMjTLkm//PKLvLy88sw0KG7R0dGSzsX+z+UX2dnZ2rdvn+Lj4y95jSuvvFI//vij2rVrl2empSsvLy+1a9dO7dq10+uvv64xY8bo3//+t77++mvFx8dfcjxQHrG0AYATi8Wit99+W//3f/+nnj17auHChSUdEgAApnL+Lrfr3fsJEyZ4fM3zT1aaOHHiRa/p7e2tDh066H//+59+/fVXR3taWprmzp2r1q1b57v0sTjFx8fLz89PEydOdPoZTZ8+XSdPnsz3SRSu7rnnHh06dEjvvPNOnnN//vmnzpw5I+ncEhNXTZs2lSTHYyKDgoIkKc+jJ4HyjBkJAPLw8vLS+++/rzvvvFP33HOPlixZ4th8CAAAXF5Wq1U33XSTxo0bp5ycHFWvXl1ffvml9u3b5/E1mzZtqvvuu09TpkzRyZMn1bJlS61YsUK7d+/O0/fFF1/UsmXL1Lp1az366KPy8fHRW2+9paysLI0bN64wH61IVKlSRcOGDdMLL7ygjh076o477tCOHTs0ZcoUXX/99U6bQl/Igw8+qHnz5umRRx7R119/rVatWslms+mXX37RvHnz9MUXX6h58+ZKSkrS6tWr1blzZ0VHR+vo0aOaMmWKatSoodatW0s6N7shLCxM06ZNU0hIiIKCghQbG5tn3wugPKGQACBfvr6++vjjj9WpUyd16dJFy5cvL+mQAAAwjblz5+qxxx7T5MmTZRiGOnTooM8//1yRkZEeX/O9995TlSpVNGfOHC1YsEBt27bV4sWL8yxVuOqqq/TNN99o2LBhSk5Olt1uV2xsrN5//33FxsYW9qMViVGjRqlKlSp68803NXjwYFWqVE
n9+vXTmDFj5Ovre8nxXl5eWrBggcaPH6/Zs2dr/vz5qlChgurUqaMnnnjCseniHXfcoV9//VXvvfee/vjjD11xxRV
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x1000 with 6 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"fig, ax = plt.subplots(2, 2, figsize=(12, 10))\n",
|
|||
|
"\n",
|
|||
|
"for index, (key, model_info) in enumerate(class_models.items()):\n",
|
|||
|
" c_matrix = model_info[\"Confusion_matrix\"]\n",
|
|||
|
" \n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"Not stroke\", \"Stroke\"]\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" \n",
|
|||
|
" disp.ax_.set_title(key)\n",
|
|||
|
"\n",
|
|||
|
"if len(class_models) < len(ax.flat):\n",
|
|||
|
" for i in range(len(class_models), len(ax.flat)):\n",
|
|||
|
" fig.delaxes(ax.flat[i])\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=0.9, bottom=0.1, hspace=0.4, wspace=0.3)\n",
|
|||
|
"\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
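{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a binary problem, `confusion_matrix` returns the counts `[[TN, FP], [FN, TP]]`. Rows are true classes and columns are predicted classes, both ordered 0 (no stroke), 1 (stroke). The minimal sketch below uses made-up labels to show how to unpack the counts with `.ravel()` and recompute recall and precision from them. The same call can be applied to the stored `Confusion_matrix` entries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# Made-up labels: 4 negatives and 2 positives\n",
"y_true_cm = np.array([0, 0, 0, 0, 1, 1])\n",
"y_pred_cm = np.array([0, 1, 0, 0, 1, 0])\n",
"\n",
"# Rows are true classes, columns are predicted classes: [[TN, FP], [FN, TP]]\n",
"tn, fp, fn, tp = confusion_matrix(y_true_cm, y_pred_cm).ravel()\n",
"print(tn, fp, fn, tp)  # 3 1 1 1\n",
"print('Recall:', tp / (tp + fn))  # 0.5\n",
"print('Precision:', tp / (tp + fp))  # 0.5"
]
},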
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Precision, Recall, Accuracy, F1:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 92,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_a7559_row0_col0 {\n",
|
|||
|
" background-color: #1f988b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_a7559_row0_col1, #T_a7559_row1_col0, #T_a7559_row1_col2, #T_a7559_row2_col3 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_a7559_row0_col2, #T_a7559_row0_col3, #T_a7559_row1_col1, #T_a7559_row2_col0 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_a7559_row0_col4 {\n",
|
|||
|
" background-color: #b7318a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_a7559_row0_col5, #T_a7559_row1_col4, #T_a7559_row1_col6, #T_a7559_row2_col7 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_a7559_row0_col6, #T_a7559_row0_col7, #T_a7559_row2_col4, #T_a7559_row2_col5 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_a7559_row1_col3, #T_a7559_row2_col1 {\n",
|
|||
|
" background-color: #1f968b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_a7559_row1_col5 {\n",
|
|||
|
" background-color: #be3885;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_a7559_row1_col7 {\n",
|
|||
|
" background-color: #9c179e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_a7559_row2_col2 {\n",
|
|||
|
" background-color: #86d549;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_a7559_row2_col6 {\n",
|
|||
|
" background-color: #8808a6;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_a7559\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_a7559_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_a7559_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_a7559_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_a7559_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_a7559_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_a7559_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_a7559_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_a7559_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_a7559_level0_row0\" class=\"row_heading level0 row0\" >mlp</th>\n",
|
|||
|
" <td id=\"T_a7559_row0_col0\" class=\"data row0 col0\" >0.400000</td>\n",
|
|||
|
" <td id=\"T_a7559_row0_col1\" class=\"data row0 col1\" >0.200000</td>\n",
|
|||
|
" <td id=\"T_a7559_row0_col2\" class=\"data row0 col2\" >0.020101</td>\n",
|
|||
|
" <td id=\"T_a7559_row0_col3\" class=\"data row0 col3\" >0.020000</td>\n",
|
|||
|
" <td id=\"T_a7559_row0_col4\" class=\"data row0 col4\" >0.950832</td>\n",
|
|||
|
" <td id=\"T_a7559_row0_col5\" class=\"data row0 col5\" >0.948141</td>\n",
|
|||
|
" <td id=\"T_a7559_row0_col6\" class=\"data row0 col6\" >0.038278</td>\n",
|
|||
|
" <td id=\"T_a7559_row0_col7\" class=\"data row0 col7\" >0.036364</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_a7559_level0_row1\" class=\"row_heading level0 row1\" >knn</th>\n",
|
|||
|
" <td id=\"T_a7559_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_a7559_row1_col1\" class=\"data row1 col1\" >0.117647</td>\n",
|
|||
|
" <td id=\"T_a7559_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_a7559_row1_col3\" class=\"data row1 col3\" >0.120000</td>\n",
|
|||
|
" <td id=\"T_a7559_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_a7559_row1_col5\" class=\"data row1 col5\" >0.912916</td>\n",
|
|||
|
" <td id=\"T_a7559_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_a7559_row1_col7\" class=\"data row1 col7\" >0.118812</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_a7559_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_a7559_row2_col0\" class=\"data row2 col0\" >0.228869</td>\n",
|
|||
|
" <td id=\"T_a7559_row2_col1\" class=\"data row2 col1\" >0.135135</td>\n",
|
|||
|
" <td id=\"T_a7559_row2_col2\" class=\"data row2 col2\" >0.884422</td>\n",
|
|||
|
" <td id=\"T_a7559_row2_col3\" class=\"data row2 col3\" >0.500000</td>\n",
|
|||
|
" <td id=\"T_a7559_row2_col4\" class=\"data row2 col4\" >0.849315</td>\n",
|
|||
|
" <td id=\"T_a7559_row2_col5\" class=\"data row2 col5\" >0.818982</td>\n",
|
|||
|
" <td id=\"T_a7559_row2_col6\" class=\"data row2 col6\" >0.363636</td>\n",
|
|||
|
" <td id=\"T_a7559_row2_col7\" class=\"data row2 col7\" >0.212766</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x23c21f18560>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 92,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(\n",
|
|||
|
" by=\"Accuracy_test\", ascending=False\n",
|
|||
|
").style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"A brief analysis of the metrics:\n",
"\n",
"1. MLP (multilayer perceptron)\n",
"\n",
"   Precision on train: 0.40, on test: 0.20\n",
"\n",
"   Recall on train: 0.02, on test: 0.02\n",
"\n",
"   Accuracy on train: 0.95, on test: 0.95\n",
"\n",
"   F1 score on train: 0.038, on test: 0.036\n",
"\n",
"   Conclusion: the high accuracy on both sets is misleading. The recall values imply that the test set contains 50 stroke cases among 1022 patients. A model that always predicts \"no stroke\" would therefore reach about 0.951 accuracy, slightly more than the MLP's 0.948. The very low recall and F1 show that the model almost always predicts the majority class and finds only 1 of the 50 stroke cases. It is strongly biased towards the negative class.\n",
"\n",
"2. KNN (k-nearest neighbors)\n",
"\n",
"   Precision on train: 1.0, on test: 0.118\n",
"\n",
"   Recall on train: 1.0, on test: 0.12\n",
"\n",
"   Accuracy on train: 1.0, on test: 0.91\n",
"\n",
"   F1 score on train: 1.0, on test: 0.119\n",
"\n",
"   Conclusion: the model clearly overfits. It predicts the training set perfectly, but on the test data precision, recall, and F1 drop to about 0.12. Its test accuracy (0.91) is also below that of the trivial \"no stroke\" predictor.\n",
"\n",
"3. Random Forest\n",
"\n",
"   Precision on train: 0.229, on test: 0.135\n",
"\n",
"   Recall on train: 0.88, on test: 0.50\n",
"\n",
"   Accuracy on train: 0.85, on test: 0.82\n",
"\n",
"   F1 score on train: 0.364, on test: 0.213\n",
"\n",
"   Conclusion: compared with the other models, Random Forest has the most balanced metrics, although they are hard to call good. It finds half of the stroke cases in the test set, but its precision is low. Only about 13.5% of the patients it flags actually had a stroke, so the model produces many false positives.\n",
"\n",
"Comparison with the baseline:\n",
"\n",
"- Baseline Accuracy: 0.52\n",
"- Baseline Precision: 0.058\n",
"- Baseline Recall: 0.58\n",
"- Baseline F1 Score: 0.106\n",
"\n",
"Accuracy: all models outperform the baseline (0.52). MLP and KNN do so by the widest margin, and Random Forest by a smaller one (0.82). With this class imbalance, however, accuracy mostly reflects how often a model predicts the majority class, so it is a weak criterion here.\n",
"\n",
"Precision: all models beat the baseline, although precision remains low. KNN (0.118) and Random Forest (0.135) have the lowest values.\n",
"\n",
"Recall: the baseline (0.58) has higher recall than all three models. MLP (0.02) and KNN (0.12) barely find the positive examples, and Random Forest comes closest (0.50).\n",
"\n",
"F1 score: Random Forest has the best F1 (0.213, about twice the baseline's 0.106), reflecting a better balance between precision and recall, but it is still far below the desired level. KNN is only marginally above the baseline (0.119), and MLP is below it (0.036).\n",
"\n",
"Conclusions on bias and variance:\n",
"\n",
"MLP: high bias. The model barely recognizes positive examples despite its high overall accuracy, and its train and test metrics are nearly identical.\n",
"\n",
"KNN: high variance. The model heavily overfits the training set and generalizes poorly to the test set.\n",
"\n",
"Random Forest: the most balanced model, with moderate bias and variance. The gap between train and test recall (0.88 vs 0.50) and F1 (0.364 vs 0.213) still indicates some overfitting. It gives the best trade-off between precision and recall, but its precision is still low.\n",
"\n",
"Summary:\n",
"\n",
"Random Forest can be considered the best model here because it offers the best balance across the metrics, but it is still far from ideal."
|
|||
|
]
|
|||
|
},
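{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the reasoning above, here is a small sketch that recomputes the test metrics from confusion-matrix counts, using only the standard library. The notebook does not print these counts. They are inferred from the table, assuming the test set contains 50 stroke cases among 1022 patients (the value consistent with all reported recall values). The helper `metrics_from_counts` and the (TP, FP) pairs are hypothetical.\n",
"\n",
"```python\n",
"def metrics_from_counts(tp, fp, fn, tn):\n",
"    # Standard definitions of the four classification metrics\n",
"    precision = tp / (tp + fp) if tp + fp else 0.0\n",
"    recall = tp / (tp + fn) if tp + fn else 0.0\n",
"    accuracy = (tp + tn) / (tp + fp + fn + tn)\n",
"    denom = precision + recall\n",
"    f1 = 2 * precision * recall / denom if denom else 0.0\n",
"    return precision, recall, accuracy, f1\n",
"\n",
"# Assumed test-set composition (inferred from the reported recall values)\n",
"n_test, n_pos = 1022, 50\n",
"n_neg = n_test - n_pos\n",
"\n",
"# Hypothetical (TP, FP) pairs inferred from the reported test precision/recall\n",
"counts = {'mlp': (1, 4), 'knn': (6, 45), 'random_forest': (25, 160)}\n",
"\n",
"results = {}\n",
"for name, (tp, fp) in counts.items():\n",
"    fn, tn = n_pos - tp, n_neg - fp\n",
"    results[name] = metrics_from_counts(tp, fp, fn, tn)\n",
"    print(name, [round(m, 6) for m in results[name]])\n",
"\n",
"# Accuracy of a trivial model that always predicts no stroke\n",
"majority_accuracy = n_neg / n_test\n",
"print('always no stroke, accuracy:', round(majority_accuracy, 6))\n",
"```\n",
"\n",
"The printed values match the test columns of the metrics table. The accuracy of always predicting no stroke (about 0.951) is higher than the MLP's 0.948."
]
},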
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Regression\n",
"\n",
"Split the dataset into training and test sets (80/20). The target feature is avg_glucose_level."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 93,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>gender</th>\n",
|
|||
|
" <th>age</th>\n",
|
|||
|
" <th>hypertension</th>\n",
|
|||
|
" <th>heart_disease</th>\n",
|
|||
|
" <th>ever_married</th>\n",
|
|||
|
" <th>work_type</th>\n",
|
|||
|
" <th>Residence_type</th>\n",
|
|||
|
" <th>bmi</th>\n",
|
|||
|
" <th>smoking_status</th>\n",
|
|||
|
" <th>stroke</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>id</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>13276</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>38.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>22.6</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>21346</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>12.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>children</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>17.8</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>59178</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>children</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>22.3</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1679</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>35.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>formerly smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1534</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>61.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>26.1</td>\n",
|
|||
|
" <td>smokes</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>30463</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>29.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>29.4</td>\n",
|
|||
|
" <td>formerly smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>41935</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>34.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>33.9</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>68483</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>60.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>41.2</td>\n",
|
|||
|
" <td>formerly smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>38617</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>28.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>29.9</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>46527</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>53.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Govt_job</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>41.9</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>4088 rows × 10 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" gender age hypertension heart_disease ever_married work_type \\\n",
|
|||
|
"id \n",
|
|||
|
"13276 Female 38.0 0 0 Yes Private \n",
|
|||
|
"21346 Female 12.0 0 0 No children \n",
|
|||
|
"59178 Female 7.0 0 0 No children \n",
|
|||
|
"1679 Male 35.0 0 0 Yes Private \n",
|
|||
|
"1534 Female 61.0 0 0 Yes Private \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"30463 Male 29.0 0 0 No Private \n",
|
|||
|
"41935 Male 34.0 0 0 No Private \n",
|
|||
|
"68483 Female 60.0 0 0 Yes Private \n",
|
|||
|
"38617 Male 28.0 0 0 Yes Self-employed \n",
|
|||
|
"46527 Male 53.0 1 1 Yes Govt_job \n",
|
|||
|
"\n",
|
|||
|
" Residence_type bmi smoking_status stroke \n",
|
|||
|
"id \n",
|
|||
|
"13276 Urban 22.6 Unknown 0 \n",
|
|||
|
"21346 Rural 17.8 Unknown 0 \n",
|
|||
|
"59178 Urban 22.3 Unknown 0 \n",
|
|||
|
"1679 Rural NaN formerly smoked 0 \n",
|
|||
|
"1534 Rural 26.1 smokes 0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"30463 Urban 29.4 formerly smoked 0 \n",
|
|||
|
"41935 Rural 33.9 never smoked 0 \n",
|
|||
|
"68483 Urban 41.2 formerly smoked 0 \n",
|
|||
|
"38617 Urban 29.9 never smoked 0 \n",
|
|||
|
"46527 Rural 41.9 never smoked 0 \n",
|
|||
|
"\n",
|
|||
|
"[4088 rows x 10 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"id\n",
|
|||
|
"13276 71.06\n",
|
|||
|
"21346 70.13\n",
|
|||
|
"59178 86.75\n",
|
|||
|
"1679 77.48\n",
|
|||
|
"1534 99.35\n",
|
|||
|
" ... \n",
|
|||
|
"30463 82.93\n",
|
|||
|
"41935 125.29\n",
|
|||
|
"68483 65.38\n",
|
|||
|
"38617 73.98\n",
|
|||
|
"46527 109.51\n",
|
|||
|
"Name: avg_glucose_level, Length: 4088, dtype: float64"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>gender</th>\n",
|
|||
|
" <th>age</th>\n",
|
|||
|
" <th>hypertension</th>\n",
|
|||
|
" <th>heart_disease</th>\n",
|
|||
|
" <th>ever_married</th>\n",
|
|||
|
" <th>work_type</th>\n",
|
|||
|
" <th>Residence_type</th>\n",
|
|||
|
" <th>bmi</th>\n",
|
|||
|
" <th>smoking_status</th>\n",
|
|||
|
" <th>stroke</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>id</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>8385</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>37.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>35.9</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>937</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>children</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3494</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>80.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>26.7</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>23850</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>66.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>33.1</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>31156</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>49.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>29.8</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>71010</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>80.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>22.8</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>39518</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>20.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Rural</td>\n",
|
|||
|
" <td>20.7</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>7780</th>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>51.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Self-employed</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>30.7</td>\n",
|
|||
|
" <td>never smoked</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>56137</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>62.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Private</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>36.3</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>33175</th>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>57.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Govt_job</td>\n",
|
|||
|
" <td>Urban</td>\n",
|
|||
|
" <td>28.5</td>\n",
|
|||
|
" <td>Unknown</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>1022 rows × 10 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" gender age hypertension heart_disease ever_married work_type \\\n",
|
|||
|
"id \n",
|
|||
|
"8385 Male 37.0 0 0 Yes Private \n",
|
|||
|
"937 Male 7.0 0 0 No children \n",
|
|||
|
"3494 Female 80.0 0 0 Yes Private \n",
|
|||
|
"23850 Male 66.0 0 0 Yes Private \n",
|
|||
|
"31156 Female 49.0 0 0 Yes Private \n",
|
|||
|
"... ... ... ... ... ... ... \n",
|
|||
|
"71010 Female 80.0 0 0 No Self-employed \n",
|
|||
|
"39518 Female 20.0 0 0 No Private \n",
|
|||
|
"7780 Male 51.0 0 0 Yes Self-employed \n",
|
|||
|
"56137 Female 62.0 0 0 Yes Private \n",
|
|||
|
"33175 Female 57.0 0 0 Yes Govt_job \n",
|
|||
|
"\n",
|
|||
|
" Residence_type bmi smoking_status stroke \n",
|
|||
|
"id \n",
|
|||
|
"8385 Urban 35.9 Unknown 0 \n",
|
|||
|
"937 Urban NaN Unknown 0 \n",
|
|||
|
"3494 Rural 26.7 Unknown 0 \n",
|
|||
|
"23850 Urban 33.1 never smoked 0 \n",
|
|||
|
"31156 Urban 29.8 never smoked 0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"71010 Urban 22.8 never smoked 0 \n",
|
|||
|
"39518 Rural 20.7 never smoked 0 \n",
|
|||
|
"7780 Urban 30.7 never smoked 0 \n",
|
|||
|
"56137 Urban 36.3 Unknown 0 \n",
|
|||
|
"33175 Urban 28.5 Unknown 1 \n",
|
|||
|
"\n",
|
|||
|
"[1022 rows x 10 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"id\n",
|
|||
|
"8385 90.78\n",
|
|||
|
"937 87.94\n",
|
|||
|
"3494 102.90\n",
|
|||
|
"23850 103.01\n",
|
|||
|
"31156 105.99\n",
|
|||
|
" ... \n",
|
|||
|
"71010 57.57\n",
|
|||
|
"39518 78.94\n",
|
|||
|
"7780 75.73\n",
|
|||
|
"56137 88.32\n",
|
|||
|
"33175 110.52\n",
|
|||
|
"Name: avg_glucose_level, Length: 1022, dtype: float64"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"features = ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'bmi', 'smoking_status', 'stroke']\n",
|
|||
|
"target = 'avg_glucose_level'\n",
|
|||
|
"\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=random_state)\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_train\", X_train)\n",
|
|||
|
"display(\"y_train\", y_train)\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_test\", X_test)\n",
|
|||
|
"display(\"y_test\", y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Let's choose a baseline for the regression task. We apply the zero rule: for every object, the prediction is the mean value of the target feature on the training set."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 94,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Baseline RMSE: 44.12711275645952\n",
|
|||
|
"Baseline RMAE: 5.662154850745081\n",
|
|||
|
"Baseline R2: -0.0010729515309222393\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import math\n",
|
|||
|
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
|
|||
|
"\n",
|
|||
|
"# Baseline prediction: the mean of y_train\n",
|
|||
|
"baseline_predictions = [y_train.mean()] * len(y_test)\n",
|
|||
|
"\n",
|
|||
|
"# Compute quality metrics for the baseline\n",
|
|||
|
"baseline_rmse = math.sqrt(\n",
|
|||
|
" mean_squared_error(y_test, baseline_predictions)\n",
|
|||
|
" )\n",
|
|||
|
"# RMAE as used here is the square root of MAE (not a standard metric)\n",
"baseline_rmae = math.sqrt(\n",
|
|||
|
" mean_absolute_error(y_test, baseline_predictions)\n",
|
|||
|
" )\n",
|
|||
|
"baseline_r2 = r2_score(y_test, baseline_predictions)\n",
|
|||
|
"\n",
|
|||
|
"print('Baseline RMSE:', baseline_rmse)\n",
|
|||
|
"print('Baseline RMAE:', baseline_rmae)\n",
|
|||
|
"print('Baseline R2:', baseline_r2)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"The following metrics were used:\n",
"\n",
"- RMSE: the square root of MSE. MSE (Mean Squared Error) is the mean of the squared differences between predicted and true values. MSE is sensitive to large errors because the deviations are squared. RMSE also penalizes large errors, but unlike MSE it is expressed in the same units as the target, which makes it easier to interpret. This makes RMSE a good choice for many practical tasks where interpretability matters.\n",
"- RMAE: the square root of MAE. MAE (Mean Absolute Error) is the mean absolute difference between predictions and true values. MAE is less sensitive to outliers than MSE and RMSE. This makes it preferable when the data contain outliers that should not strongly affect the overall evaluation. Taking the square root of MAE is not standard practice. RMAE is measured in the square root of the target's units, so MAE itself is easier to interpret (for the baseline, MAE = RMAE² ≈ 32.06).\n",
"- R2 (coefficient of determination): R2 measures the proportion of the variance of the dependent variable explained by the independent variables in the model. It is a good way to assess model adequacy. Values close to 1 mean the model explains the data well, 0 corresponds to always predicting the mean, and negative values mean the model is worse than that. R2 is best suited to comparing models on the same data.\n",
"\n",
"The baseline's metrics therefore show how much better (or worse) a model is than simply predicting the mean. As expected, the baseline's R2 is close to 0. It is slightly negative because the mean is taken from the training set rather than from the test set."
|
|||
|
]
|
|||
|
},
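{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the metrics above on made-up numbers, using only the standard library. It shows that a constant prediction equal to the training mean gives an R2 of at most 0 on the test set. R2 is exactly 0 only if the constant equals the test mean, which is why the baseline R2 is slightly negative. The data and the helper `regression_metrics` are hypothetical.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def regression_metrics(y_true, y_pred):\n",
"    n = len(y_true)\n",
"    mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n\n",
"    mae = sum(abs(t - p) for t, p in zip(y_true, y_pred)) / n\n",
"    mean_true = sum(y_true) / n\n",
"    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # total sum of squares\n",
"    ss_res = mse * n  # residual sum of squares\n",
"    r2 = 1 - ss_res / ss_tot\n",
"    # RMAE as defined in this notebook: the square root of MAE\n",
"    return {'RMSE': math.sqrt(mse), 'MAE': mae, 'RMAE': math.sqrt(mae), 'R2': r2}\n",
"\n",
"# Hypothetical glucose-like values\n",
"y_train = [80.0, 95.0, 110.0, 75.0, 140.0]\n",
"y_test = [90.0, 100.0, 70.0, 120.0]\n",
"\n",
"# Zero-rule baseline: always predict the training mean\n",
"train_mean = sum(y_train) / len(y_train)\n",
"baseline = [train_mean] * len(y_test)\n",
"scores = regression_metrics(y_test, baseline)\n",
"print(scores)\n",
"```\n",
"\n",
"The same values can be obtained with `mean_squared_error`, `mean_absolute_error`, and `r2_score` from `sklearn.metrics`. The hand-written version only makes the formulas explicit."
]
},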
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Building the regression pipeline"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 95,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"columns_to_drop = []\n",
|
|||
|
"columns_not_to_modify = [\"hypertension\", \"heart_disease\", \"stroke\", \"avg_glucose_level\"]\n",
|
|||
|
"\n",
|
|||
|
"num_columns = [\n",
|
|||
|
" column\n",
|
|||
|
" for column in df.columns\n",
|
|||
|
" if column not in columns_to_drop\n",
|
|||
|
" and column not in columns_not_to_modify\n",
|
|||
|
" and df[column].dtype != \"object\"\n",
|
|||
|
"]\n",
|
|||
|
"\n",
|
|||
|
"cat_columns = [\n",
|
|||
|
" column\n",
|
|||
|
" for column in df.columns\n",
|
|||
|
" if column not in columns_to_drop\n",
|
|||
|
" and column not in columns_not_to_modify\n",
|
|||
|
" and df[column].dtype == \"object\"\n",
|
|||
|
"]\n",
|
|||
|
"\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"pipeline_end_reg = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" ]\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
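{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate what the `features_preprocessing` step produces, here is a reduced sketch on a hypothetical three-row table. It assumes scikit-learn 1.2 or later for `sparse_output`. The numeric columns are imputed with the median and standardized. The categorical column is imputed with a constant and one-hot encoded with the first category dropped. The remaining columns pass through unchanged, which is why the transformed table has a `gender_Male` column but no `gender_Female`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"\n",
"# Hypothetical miniature of the stroke data\n",
"toy = pd.DataFrame({\n",
"    'age': [38.0, 12.0, 61.0],\n",
"    'bmi': [22.6, np.nan, 26.1],\n",
"    'gender': ['Female', 'Male', 'Female'],\n",
"    'stroke': [0, 0, 1],\n",
"})\n",
"\n",
"# Numeric branch: median imputation, then standardization\n",
"num = Pipeline([('imputer', SimpleImputer(strategy='median')), ('scaler', StandardScaler())])\n",
"# Categorical branch: constant imputation, then one-hot encoding without the first category\n",
"cat = Pipeline([\n",
"    ('imputer', SimpleImputer(strategy='constant', fill_value='unknown')),\n",
"    ('encoder', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop='first')),\n",
"])\n",
"ct = ColumnTransformer(\n",
"    [('num', num, ['age', 'bmi']), ('cat', cat, ['gender'])],\n",
"    remainder='passthrough',  # 'stroke' is kept as is\n",
"    verbose_feature_names_out=False,\n",
")\n",
"out = ct.fit_transform(toy)\n",
"names = list(ct.get_feature_names_out())\n",
"print(names)\n",
"print(out)\n",
"```"
]
},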
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Now let's check how the pipeline works:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 96,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>age</th>\n",
|
|||
|
" <th>bmi</th>\n",
|
|||
|
" <th>gender_Male</th>\n",
|
|||
|
" <th>gender_Other</th>\n",
|
|||
|
" <th>ever_married_Yes</th>\n",
|
|||
|
" <th>work_type_Never_worked</th>\n",
|
|||
|
" <th>work_type_Private</th>\n",
|
|||
|
" <th>work_type_Self-employed</th>\n",
|
|||
|
" <th>work_type_children</th>\n",
|
|||
|
" <th>Residence_type_Urban</th>\n",
|
|||
|
" <th>smoking_status_formerly smoked</th>\n",
|
|||
|
" <th>smoking_status_never smoked</th>\n",
|
|||
|
" <th>smoking_status_smokes</th>\n",
|
|||
|
" <th>hypertension</th>\n",
|
|||
|
" <th>heart_disease</th>\n",
|
|||
|
" <th>stroke</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>id</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>13276</th>\n",
|
|||
|
" <td>-0.236211</td>\n",
|
|||
|
" <td>-0.826056</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>21346</th>\n",
|
|||
|
" <td>-1.386874</td>\n",
|
|||
|
" <td>-1.455413</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>59178</th>\n",
|
|||
|
" <td>-1.608155</td>\n",
|
|||
|
" <td>-0.865391</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1679</th>\n",
|
|||
|
" <td>-0.368980</td>\n",
|
|||
|
" <td>-0.104918</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1534</th>\n",
|
|||
|
" <td>0.781682</td>\n",
|
|||
|
" <td>-0.367150</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>30463</th>\n",
|
|||
|
" <td>-0.634518</td>\n",
|
|||
|
" <td>0.065532</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>41935</th>\n",
|
|||
|
" <td>-0.413236</td>\n",
|
|||
|
" <td>0.655554</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>68483</th>\n",
|
|||
|
" <td>0.737426</td>\n",
|
|||
|
" <td>1.612701</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>38617</th>\n",
|
|||
|
" <td>-0.678774</td>\n",
|
|||
|
" <td>0.131090</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>46527</th>\n",
|
|||
|
" <td>0.427632</td>\n",
|
|||
|
" <td>1.704482</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>4088 rows × 16 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" age bmi gender_Male gender_Other ever_married_Yes \\\n",
|
|||
|
"id \n",
|
|||
|
"13276 -0.236211 -0.826056 0.0 0.0 1.0 \n",
|
|||
|
"21346 -1.386874 -1.455413 0.0 0.0 0.0 \n",
|
|||
|
"59178 -1.608155 -0.865391 0.0 0.0 0.0 \n",
|
|||
|
"1679 -0.368980 -0.104918 1.0 0.0 1.0 \n",
|
|||
|
"1534 0.781682 -0.367150 0.0 0.0 1.0 \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"30463 -0.634518 0.065532 1.0 0.0 0.0 \n",
|
|||
|
"41935 -0.413236 0.655554 1.0 0.0 0.0 \n",
|
|||
|
"68483 0.737426 1.612701 0.0 0.0 1.0 \n",
|
|||
|
"38617 -0.678774 0.131090 1.0 0.0 1.0 \n",
|
|||
|
"46527 0.427632 1.704482 1.0 0.0 1.0 \n",
|
|||
|
"\n",
|
|||
|
" work_type_Never_worked work_type_Private work_type_Self-employed \\\n",
|
|||
|
"id \n",
|
|||
|
"13276 0.0 1.0 0.0 \n",
|
|||
|
"21346 0.0 0.0 0.0 \n",
|
|||
|
"59178 0.0 0.0 0.0 \n",
|
|||
|
"1679 0.0 1.0 0.0 \n",
|
|||
|
"1534 0.0 1.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"30463 0.0 1.0 0.0 \n",
|
|||
|
"41935 0.0 1.0 0.0 \n",
|
|||
|
"68483 0.0 1.0 0.0 \n",
|
|||
|
"38617 0.0 0.0 1.0 \n",
|
|||
|
"46527 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" work_type_children Residence_type_Urban \\\n",
|
|||
|
"id \n",
|
|||
|
"13276 0.0 1.0 \n",
|
|||
|
"21346 1.0 0.0 \n",
|
|||
|
"59178 1.0 1.0 \n",
|
|||
|
"1679 0.0 0.0 \n",
|
|||
|
"1534 0.0 0.0 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"30463 0.0 1.0 \n",
|
|||
|
"41935 0.0 0.0 \n",
|
|||
|
"68483 0.0 1.0 \n",
|
|||
|
"38617 0.0 1.0 \n",
|
|||
|
"46527 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" smoking_status_formerly smoked smoking_status_never smoked \\\n",
|
|||
|
"id \n",
|
|||
|
"13276 0.0 0.0 \n",
|
|||
|
"21346 0.0 0.0 \n",
|
|||
|
"59178 0.0 0.0 \n",
|
|||
|
"1679 1.0 0.0 \n",
|
|||
|
"1534 0.0 0.0 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"30463 1.0 0.0 \n",
|
|||
|
"41935 0.0 1.0 \n",
|
|||
|
"68483 1.0 0.0 \n",
|
|||
|
"38617 0.0 1.0 \n",
|
|||
|
"46527 0.0 1.0 \n",
|
|||
|
"\n",
|
|||
|
" smoking_status_smokes hypertension heart_disease stroke \n",
|
|||
|
"id \n",
|
|||
|
"13276 0.0 0 0 0 \n",
|
|||
|
"21346 0.0 0 0 0 \n",
|
|||
|
"59178 0.0 0 0 0 \n",
|
|||
|
"1679 0.0 0 0 0 \n",
|
|||
|
"1534 1.0 0 0 0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"30463 0.0 0 0 0 \n",
|
|||
|
"41935 0.0 0 0 0 \n",
|
|||
|
"68483 0.0 0 0 0 \n",
|
|||
|
"38617 0.0 0 0 0 \n",
|
|||
|
"46527 0.0 1 1 0 \n",
|
|||
|
"\n",
|
|||
|
"[4088 rows x 16 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 96,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"preprocessing_result = pipeline_end_reg.fit_transform(X_train)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end_reg.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"preprocessed_df"
|
|||
|
]
|
|||
|
},
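  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The column names in the preprocessed table above (for example `gender_Male` and `gender_Other`, but no `gender_Female`) are consistent with one-hot encoding that drops the first category of each categorical feature. Because `pipeline_end_reg` is defined earlier in the notebook, the following cell is only a sketch under assumptions about its structure. It assumes a `ColumnTransformer` with `StandardScaler` for numeric columns and `OneHotEncoder(drop='first')` for categorical ones, with `verbose_feature_names_out=False` so that names carry no transformer prefix. The toy data frame and all `toy_*` names are hypothetical, and `sparse_output` requires scikit-learn 1.2 or newer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (assumption): how drop='first' one-hot encoding yields names like gender_Male\n",
    "import pandas as pd\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
    "\n",
    "# Hypothetical toy data that uses column names from the stroke dataset\n",
    "toy_df = pd.DataFrame({\n",
    "    'age': [67.0, 61.0, 80.0, 49.0],\n",
    "    'bmi': [36.6, 28.1, 32.5, 34.4],\n",
    "    'gender': ['Male', 'Female', 'Male', 'Other'],\n",
    "})\n",
    "\n",
    "toy_preprocessor = ColumnTransformer(\n",
    "    [\n",
    "        ('num', StandardScaler(), ['age', 'bmi']),\n",
    "        # drop='first' removes the alphabetically first category (here 'Female')\n",
    "        ('cat', OneHotEncoder(drop='first', sparse_output=False), ['gender']),\n",
    "    ],\n",
    "    verbose_feature_names_out=False,  # keep 'age' instead of 'num__age'\n",
    ")\n",
    "\n",
    "toy_result = toy_preprocessor.fit_transform(toy_df)\n",
    "toy_names = list(toy_preprocessor.get_feature_names_out())\n",
    "print(toy_names)"
   ]
  },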
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
    "Let us tune the hyperparameters of each selected model with a grid search and collect the tuned models in a dictionary.\n",
    "\n",
    "knn -- k-nearest neighbors\n",
    "\n",
    "random_forest -- random forest (an ensemble of decision trees)\n",
    "\n",
    "mlp -- multilayer perceptron (a neural network)"
|
|||
|
]
|
|||
|
},
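  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Grid search fits every combination of the listed values, so the size of each grid determines the cost of the search. The following sketch counts the combinations with scikit-learn's `ParameterGrid`, using copies of the grids defined in the grid search cell below. The value of `random_state` does not affect the count. The sketch uses 9 because that is the value in the printed best parameters, which is an assumption about the notebook's `random_state` variable. With the default 5-fold cross-validation of `GridSearchCV`, each combination is fitted five times."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count the hyperparameter combinations explored by the grid search\n",
    "import numpy as np\n",
    "from sklearn.model_selection import ParameterGrid\n",
    "\n",
    "RANDOM_STATE = 9  # assumption: matches random_state in the printed best parameters\n",
    "\n",
    "grids = {\n",
    "    'knn': {\n",
    "        'n_neighbors': list(range(1, 31)),\n",
    "        'weights': ['uniform', 'distance'],\n",
    "        'n_jobs': [-1],\n",
    "    },\n",
    "    'random_forest': {\n",
    "        'n_estimators': [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
    "        'max_features': ['sqrt', 'log2', 2],\n",
    "        'max_depth': [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
    "        'criterion': ['squared_error', 'absolute_error', 'poisson'],\n",
    "        'random_state': [RANDOM_STATE],\n",
    "        'n_jobs': [-1],\n",
    "    },\n",
    "    'mlp': {\n",
    "        'solver': ['adam'],\n",
    "        'max_iter': list(range(1000, 2001, 100)),\n",
    "        'alpha': 10.0 ** -np.arange(1, 10),\n",
    "        'hidden_layer_sizes': np.arange(10, 15),\n",
    "        'early_stopping': [True, False],\n",
    "        'random_state': [RANDOM_STATE],\n",
    "    },\n",
    "}\n",
    "\n",
    "CV_FOLDS = 5  # GridSearchCV default when cv is not given\n",
    "grid_sizes = {name: len(ParameterGrid(grid)) for name, grid in grids.items()}\n",
    "for name, size in grid_sizes.items():\n",
    "    print(f'{name}: {size} combinations, {size * CV_FOLDS} fits')"
   ]
  },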
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
      "Best parameters for knn: {'n_jobs': -1, 'n_neighbors': 30, 'weights': 'uniform'}\n",
      "Best parameters for random_forest: {'criterion': 'squared_error', 'max_depth': 7, 'max_features': 'sqrt', 'n_estimators': 250, 'n_jobs': -1, 'random_state': 9}\n",
      "Best parameters for mlp: {'alpha': np.float64(1e-06), 'early_stopping': False, 'hidden_layer_sizes': np.int64(13), 'max_iter': 1000, 'random_state': 9, 'solver': 'adam'}\n"
     ]
    }
   ],
   "source": [
    "# Candidate hyperparameter values for each model\n",
    "param_grids = {\n",
    "    \"knn\": {\n",
    "        \"n_neighbors\": list(range(1, 31)),\n",
    "        \"weights\": ['uniform', 'distance'],\n",
    "        \"n_jobs\": [-1]\n",
    "    },\n",
    "    \"random_forest\": {\n",
    "        \"n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
    "        \"max_features\": [\"sqrt\", \"log2\", 2],\n",
    "        \"max_depth\": [2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
    "        \"criterion\": [\"squared_error\", \"absolute_error\", \"poisson\"],\n",
    "        \"random_state\": [random_state],\n",
    "        \"n_jobs\": [-1]\n",
    "    },\n",
    "    \"mlp\": {\n",
    "        \"solver\": ['adam'],\n",
    "        \"max_iter\": [1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000],\n",
    "        \"alpha\": 10.0 ** -np.arange(1, 10),\n",
    "        \"hidden_layer_sizes\": np.arange(10, 15),  # a single hidden layer with 10-14 neurons\n",
    "        \"early_stopping\": [True, False],\n",
    "        \"random_state\": [random_state]\n",
    "    }\n",
    "}\n",
    "\n",
    "# Create the model instances\n",
    "models = {\n",
    "    \"knn\": neighbors.KNeighborsRegressor(),\n",
    "    \"random_forest\": ensemble.RandomForestRegressor(),\n",
    "    \"mlp\": neural_network.MLPRegressor()\n",
    "}\n",
    "\n",
    "# Dictionary that stores each model configured with its best parameters\n",
    "class_models = {}\n",
    "\n",
    "# Run the grid search for each model\n",
    "for model_name, model in models.items():\n",
    "    # Create a GridSearchCV object for the current model\n",
    "    gs_optimizer = GridSearchCV(estimator=model, param_grid=param_grids[model_name], scoring='neg_mean_squared_error', n_jobs=-1)\n",
    "    \n",
    "    # Fit GridSearchCV on the preprocessed training data\n",
    "    gs_optimizer.fit(preprocessed_df, y_train.values.ravel())\n",
    "    \n",
    "    # Retrieve the best parameters\n",
    "    best_params = gs_optimizer.best_params_\n",
    "    print(f\"Best parameters for {model_name}: {best_params}\")\n",
    "    \n",
    "    class_models[model_name] = {\n",
    "        \"model\": model.set_params(**best_params)  # configure the model with its best parameters\n",
    "    }"
|
|||
|
]
|
|||
|
},
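  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `scoring='neg_mean_squared_error'`, `GridSearchCV` maximizes the negated MSE, so the cross-validated RMSE of the best candidate is `sqrt(-best_score_)`. In the search above, `GridSearchCV` runs on `preprocessed_df`, which was produced by a pipeline fitted on the whole training set before cross-validation. A common alternative is to search over a `Pipeline` and address model parameters with the `step__parameter` prefix, so that the scaler is refitted inside each fold. This technique is swapped in here for illustration and is not what the notebook does. The sketch uses synthetic data from `make_regression` and a small hypothetical grid, and all `demo_*` names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: grid search over a Pipeline and conversion of the best score to RMSE\n",
    "import math\n",
    "\n",
    "from sklearn.datasets import make_regression\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Synthetic regression data (a stand-in for the stroke data)\n",
    "X_demo, y_demo = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=9)\n",
    "\n",
    "demo_pipeline = Pipeline([\n",
    "    ('scaler', StandardScaler()),  # refitted on the training part of every CV fold\n",
    "    ('model', KNeighborsRegressor()),\n",
    "])\n",
    "\n",
    "# Parameters of a pipeline step are addressed as <step name>__<parameter>\n",
    "demo_grid = {\n",
    "    'model__n_neighbors': [5, 10, 20],\n",
    "    'model__weights': ['uniform', 'distance'],\n",
    "}\n",
    "\n",
    "demo_search = GridSearchCV(demo_pipeline, demo_grid, scoring='neg_mean_squared_error', cv=5)\n",
    "demo_search.fit(X_demo, y_demo)\n",
    "\n",
    "# best_score_ is the mean negated MSE over the folds, so flip the sign before taking the root\n",
    "demo_cv_rmse = math.sqrt(-demo_search.best_score_)\n",
    "print(demo_search.best_params_)\n",
    "print(f'CV RMSE of the best candidate: {demo_cv_rmse:.3f}')"
   ]
  },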
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
    "Next, we train the models and evaluate their performance."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 98,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: knn\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" \n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end_reg), (\"model\", model)])\n",
|
|||
|
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
" y_train_pred = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_pred = model_pipeline.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"train_preds\"] = y_train_pred\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_pred\n",
|
|||
|
" \n",
|
|||
|
" class_models[model_name][\"RMSE_train\"] = math.sqrt(\n",
|
|||
|
" mean_squared_error(y_train, y_train_pred)\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"RMSE_test\"] = math.sqrt(\n",
|
|||
|
" mean_squared_error(y_test, y_test_pred)\n",
|
|||
|
" )\n",
|
|||
|
    "    # Note: RMAE here is the square root of the MAE (a non-standard metric), not the MAE itself\n",
    "    class_models[model_name][\"RMAE_test\"] = math.sqrt(\n",
|
|||
|
" mean_absolute_error(y_test, y_test_pred)\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"R2_test\"] = r2_score(y_test, y_test_pred)"
|
|||
|
]
|
|||
|
},
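  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The metrics stored above can be checked by hand. The following sketch uses four made-up values. RMSE is the square root of the mean squared error, and R2 is one minus the ratio of the residual sum of squares to the total sum of squares. The value the notebook stores as `RMAE_test` is the square root of the mean absolute error, which is not a standard metric; the usual MAE is the value before the square root."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute RMSE, MAE, RMAE and R2 by hand and compare them with scikit-learn\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "\n",
    "y_true = np.array([3.0, 5.0, 2.0, 7.0])  # made-up actual values\n",
    "y_hat = np.array([2.5, 5.0, 4.0, 8.0])   # made-up predictions\n",
    "\n",
    "errors = y_true - y_hat\n",
    "rmse = math.sqrt(np.mean(errors ** 2))\n",
    "mae = np.mean(np.abs(errors))\n",
    "rmae = math.sqrt(mae)  # what the notebook stores as RMAE_test\n",
    "r2 = 1 - np.sum(errors ** 2) / np.sum((y_true - y_true.mean()) ** 2)\n",
    "\n",
    "print(f'RMSE={rmse:.4f}, MAE={mae:.4f}, RMAE={rmae:.4f}, R2={r2:.4f}')\n",
    "print('sklearn RMSE:', math.sqrt(mean_squared_error(y_true, y_hat)))\n",
    "print('sklearn MAE:', mean_absolute_error(y_true, y_hat))\n",
    "print('sklearn R2:', r2_score(y_true, y_hat))"
   ]
  },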
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
    "Training RMSE and test-set RMSE, RMAE and R2 for each model:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 99,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_0650d_row0_col0, #T_0650d_row2_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_0650d_row0_col1, #T_0650d_row1_col0 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_0650d_row0_col2, #T_0650d_row2_col3 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_0650d_row0_col3, #T_0650d_row2_col2 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_0650d_row1_col1 {\n",
|
|||
|
" background-color: #20938c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_0650d_row1_col2 {\n",
|
|||
|
" background-color: #b42e8d;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_0650d_row1_col3 {\n",
|
|||
|
" background-color: #c8437b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_0650d_row2_col0 {\n",
|
|||
|
" background-color: #73d056;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_0650d\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_0650d_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
|
|||
|
" <th id=\"T_0650d_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
|
|||
|
" <th id=\"T_0650d_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
|
|||
|
" <th id=\"T_0650d_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_0650d_level0_row0\" class=\"row_heading level0 row0\" >mlp</th>\n",
|
|||
|
" <td id=\"T_0650d_row0_col0\" class=\"data row0 col0\" >42.583378</td>\n",
|
|||
|
" <td id=\"T_0650d_row0_col1\" class=\"data row0 col1\" >40.922194</td>\n",
|
|||
|
" <td id=\"T_0650d_row0_col2\" class=\"data row0 col2\" >5.533579</td>\n",
|
|||
|
" <td id=\"T_0650d_row0_col3\" class=\"data row0 col3\" >0.139061</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_0650d_level0_row1\" class=\"row_heading level0 row1\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_0650d_row1_col0\" class=\"data row1 col0\" >40.324186</td>\n",
|
|||
|
" <td id=\"T_0650d_row1_col1\" class=\"data row1 col1\" >41.085298</td>\n",
|
|||
|
" <td id=\"T_0650d_row1_col2\" class=\"data row1 col2\" >5.544678</td>\n",
|
|||
|
" <td id=\"T_0650d_row1_col3\" class=\"data row1 col3\" >0.132184</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_0650d_level0_row2\" class=\"row_heading level0 row2\" >knn</th>\n",
|
|||
|
" <td id=\"T_0650d_row2_col0\" class=\"data row2 col0\" >42.164413</td>\n",
|
|||
|
" <td id=\"T_0650d_row2_col1\" class=\"data row2 col1\" >41.826505</td>\n",
|
|||
|
" <td id=\"T_0650d_row2_col2\" class=\"data row2 col2\" >5.550755</td>\n",
|
|||
|
" <td id=\"T_0650d_row2_col3\" class=\"data row2 col3\" >0.100590</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x23c2371daf0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 99,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"reg_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
|
|||
|
"]\n",
|
|||
|
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
|
|||
|
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
|
|||
|
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
|
|||
|
]
|
|||
|
},
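  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The overfitting remark in the conclusions compares training and test RMSE. The following minimal sketch reproduces this comparison from the values printed in the table above and adds a gap column. The values are copied by hand, so the sketch does not depend on `reg_metrics`. A negative gap means the test error is lower than the training error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank the models by test RMSE and compute the train/test RMSE gap\n",
    "import pandas as pd\n",
    "\n",
    "# Values copied from the metrics table above\n",
    "metrics = pd.DataFrame(\n",
    "    {\n",
    "        'RMSE_train': [42.583378, 40.324186, 42.164413],\n",
    "        'RMSE_test': [40.922194, 41.085298, 41.826505],\n",
    "        'R2_test': [0.139061, 0.132184, 0.100590],\n",
    "    },\n",
    "    index=['mlp', 'random_forest', 'knn'],\n",
    ")\n",
    "\n",
    "metrics['RMSE_gap'] = metrics['RMSE_test'] - metrics['RMSE_train']\n",
    "ranked = metrics.sort_values(by='RMSE_test')\n",
    "print(ranked.round(3))"
   ]
  },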
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
    "Results as plots (predicted vs. actual values):"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 100,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: knn\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIjCAYAAAAJLyrXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3xT9foH8M852UnTvRhllb0EBXHjYCgKKo6Lyu+K4uIWEERFFEXFq4jIEHBeRa9XXNd1RUUQEVRQUEDBsnfpXtn7nN8foaHpgKZNmo7P+/XyJT1Jkydpcs73+Y7nK8iyLIOIiIiIiIjqTIx2AERERERERM0NEykiIiIiIqIQMZEiIiIiIiIKERMpIiIiIiKiEDGRIiIiIiIiChETKSIiIiIiohAxkSIiIiIiIgoREykiIiIiIqIQMZEiIiIiIiIKERMpIiJqcQRBwJNPPhny7x05cgSCIODtt9+u1/NOmDABMTEx9fpdIiJqXphIERFRRLz99tsQBAGCIOCnn36qdrssy8jIyIAgCLjmmmuiECEREVH9MZEiIqKI0mq1WLlyZbXjGzZsQE5ODjQaTRSiIiIiahgmUkREFFGjRo3Cxx9/DK/XG3R85cqVOOecc5Cenh6lyIiIiOqPiRQREUXULbfcgpKSEqxduzZwzO1247///S9uvfXWGn/HZrNhxowZyMjIgEajQY8ePbBgwQLIshx0P5fLhenTpyMlJQVGoxFjxoxBTk5OjY954sQJ3HnnnUhLS4NGo0GfPn3w1ltvnTF+j8eDPXv2IC8vL4RXfcqOHTuQkpKCSy+9FFarFQDQqVMnXHPNNfjpp59w7rnnQqvVokuXLvj3v/8d9LsV0yN//vlnPPDAA0hJSYHBYMD111+PoqKiesVDREThwUSKiIgiqlOnTjj//PPx/vvvB4598803MJlMGDduXLX7y7KMMWPGYNGiRbjyyiuxcOFC9OjRAw899BAeeOCBoPveddddWLx4MUaMGIF58+ZBpVLh6quvrvaYBQUFOO+88/Ddd99h8uTJWLJkCbp27YqJEydi8eLFp43/xIkT6NWrF2bNmhXya9+6dSsuv/xyDBw4EN98801QIYoDBw7gxhtvxPDhw/Hiiy8iISEBEyZMwF9//VXtcaZMmYI//vgDc+bMwaRJk/Dll19i8uTJIcdDRETho4x2AERE1PLdeuutmDVrFhwOB3Q6Hd577z0MHToUbdu2rXbf//3vf/j+++/xzDPP4LHHHgMAZGVl4aabbsKSJUswefJkZGZm4o8//sB//vMf/OMf/8Dy5csD97vtttvw559/Bj3mY489Bp/Ph507dyIpKQkAcN999+GWW27Bk08+iXvvvRc6nS6sr/nnn3/GqFGjcPHFF+OTTz6pthZs79692LhxIy6++GIAwM0334yMjAysWLECCxYsCLpvUlIS1qxZA0EQAACSJOGll16CyWRCXFxcWOMmIqK64YgUERFF3M033wyHw4FVq1bBYrFg1apVtU7r+/rrr6FQKDB16tSg4zNmzIAsy/jmm28C9wNQ7X7Tpk0L+lmWZXzyyScYPXo0ZFlGcXFx4L+RI0fCZDJh27ZttcbeqVMnyLIcUkn09evXY+TIkbjiiivw6aef1lhQo3fv3oEkCgBSUlLQo0cPHDp0qNp977nnnkASBQAXX3wxfD4fjh49WueYiIgovDgiRUREEZeSkoJhw4Zh5cqVsNvt8Pl8uPHGG2u879GjR9G2bVsYjcag47169QrcXvF/URSRmZkZdL8ePXoE/VxUVITy8nK8/vrreP3112t8zsLCwnq9rpo4nU5cffXVOOecc/DRRx9Bqaz5UtuhQ4dqxxISElBWVnbG+yYkJABAjfclIqLGwUSKiIgaxa233oq7774b+fn5uOqqqxAfH98ozytJEgBg/PjxuP3222u8T//+/cP2fBqNBqNGjcIXX3yB1atX17pHlkKhqPF41YIaod6XiIgaBx
MpIiJqFNdffz3uvfde/PLLL/jwww9rvV/Hjh3x3XffwWKxBI1K7dmzJ3B7xf8lScLBgweDRqH27t0b9HgVFf18Ph+GDRsWzpdUI0EQ8N577+Haa6/FTTfdhG+++QaXXnppxJ+XiIgaF9dIERFRo4iJicErr7yCJ598EqNHj671fqNGjYLP58OyZcuCji9atAiCIOCqq64CgMD/X3rppaD7Va3Cp1AocMMNN+CTTz7Brl27qj3fmcqI16f8uVqtxqefforBgwdj9OjR2LJlS51/l4iImgeOSBERUaOpbWpdZaNHj8Zll12Gxx57DEeOHMFZZ52FNWvW4IsvvsC0adMCa6IGDBiAW265BS+//DJMJhMuuOACrFu3DgcOHKj2mPPmzcP69esxZMgQ3H333ejduzdKS0uxbds2fPfddygtLa01nory57fffntIBSd0Oh1WrVqFyy+/HFdddRU2bNiAvn371vn3iYioaeOIFBERNSmiKOJ///sfpk2bhlWrVmHatGnIzs7GCy+8gIULFwbd96233sLUqVOxevVqPPzww/B4PPjqq6+qPWZaWhq2bNmCO+64A59++mlgL6nS0lI8//zzEXstsbGx+Pbbb5Geno7hw4fXmOQREVHzJMhcqUpERERERBQSjkgRERERERGFiIkUERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREVGIuCEvAEmSkJubC6PRCEEQoh0OERERERFFiSzLsFgsaNu2LUSx9nEnJlIAcnNzkZGREe0wiIiIiIioiTh+/Djat29f6+1MpAAYjUYA/jcrNjY2ytEQEREREVG0mM1mZGRkBHKE2jCRAgLT+WJjY5lIERERERHRGZf8sNgEERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREVGImEgRERERERGFiIkUERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREVGImEgRERERERGFiIkUERERERFRiJhIERERERERhYiJFBERERERUYiYSBEREREREYWIiRQREREREUWFLMs4cOBAtMOoFyZSRERERETU6Hbt2oXLLrsMgwYNQmFhYbTDCRkTKSIiIiIiajQmkwnTpk3DgAEDsGHDBphMJsyaNSvaYYVMGe0AiIiIiIio5ZMkCe+++y4efvjhoBGozMxMjB07NoqR1Q8TKSIiIiIiiqgdO3YgKysLmzZtChzT6XR49NFH8eCDD0Kr1UYxuvphIkVERERERBFhMpnw6KOP4tVXX4UkSYHjY8eOxcKFC9GxY8coRtcwTKSIiIiIiCgiZFnGf//730AS1b17dyxduhQjRoyIcmQNx2ITREREREQUEfHx8Zg/fz4MBgPmzZuHnTt3togkCmAiRUREREREYVBSUoIpU6bgxIkTQcf/7//+D/v378fMmTOhVqujFF34cWofERERERHVm8/nwxtvvIHHHnsMpaWlKCkpwcqVKwO3i6KINm3aRDHCyOCIFBERERER1cvmzZtx7rnnYtKkSSgtLQUArFq1Cnl5eVGOLPKYSBERERERUUgKCwtx55134oILLsC2bdsCx8ePH4+9e/e2yBGoqji1j4iIiIiI6sTr9eKVV17B448/DpPJFDjer18/LF++HBdffHEUo2tcTKSIiIiIiKhOrrvuOnz11VeBn+Pi4jB37lxMmjQJSmXrSi04tY+IiIiIiOrk73//e+DfEyZMwN69ezFlypRWl0QBHJEiIiIiIqIaeDweWCwWJCYmBo7ddNNN+Omnn3DLLbfg/PPPj2J00ccRKSIiIiIiCvLDDz9g4MCBmDhxYtBxQRDw0ksvtfokCmAiRUREREREJ+Xk5OCWW27BZZddhr/++guff/45Vq9eHe2wmiQmUkRERERErZzb7cb8+fPRs2dPfPDBB4HjgwcPRmpqahQja7q4RoqIiIiIqBVbu3
YtpkyZgr179waOJSUlYd68ebjzzjshihx7qQnfFSIiIiKiVujYsWO44YYbMGLEiEASJYoi/vGPf2Dfvn246667mES
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x600 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: random_forest\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIjCAYAAAAJLyrXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hT9f4H8Pc5WU3TNN0to6VQ9kZRRFQQGYKCigsUBcGFDEFUhB8KilcBkSG4ryLXK67ruqIyRMR7BS8ooGDZG0p3m6bZ4/z+CD02dNC0aZO079fz8NB+c5p8cnJyzvl8pyBJkgQiIiIiIiKqMTHYARAREREREYUbJlJERERERER+YiJFRERERETkJyZSREREREREfmIiRURERERE5CcmUkRERERERH5iIkVEREREROQnJlJERERERER+YiJFRERERETkJyZSRERULUEQMH/+fL//7sSJExAEAe+9917AY2oo4fAeXnrpJbRp0wYKhQI9e/YMdjhERE0GEykiojDw3nvvQRAECIKA//73vxUelyQJqampEAQBN954YxAipGDYuHEjnnzySfTr1w+rV6/GCy+8EOyQKsjMzMT8+fNx4sSJYIdCRBRQymAHQERENRcREYG1a9fiqquu8infunUrzpw5A41GE6TIKBh++OEHiKKId955B2q1OtjhVCozMxPPPvssBgwYgPT09GCHQ0QUMGyRIiIKI8OHD8enn34Kl8vlU7527VpceumlSElJCVJkDcNsNgc7hJCSm5sLrVYbsCRKkiRYrdaAPBcRUWPHRIqIKIyMGTMGBQUF2LRpk1zmcDjwr3/9C3fddVelf2M2mzFz5kykpqZCo9GgQ4cOWLJkCSRJ8tnObrdjxowZSExMhF6vx8iRI3HmzJlKn/Ps2bOYMGECkpOTodFo0KVLF7z77rsXjd/pdOLAgQM4d+7cRbcdP348oqKicPToUQwfPhx6vR533303AOA///kPbr/9dqSlpUGj0SA1NRUzZsyokASUPcfZs2dx8803IyoqComJiXj88cfhdrt9ti0uLsb48eNhMBgQExODcePGobi4uNLYfvjhB1x99dXQ6XSIiYnBTTfdhP379/tsM3/+fAiCgEOHDmHs2LEwGAxITEzE008/DUmScPr0adx0002Ijo5GSkoKXn755Yvuk/IEQcDq1athNpvlbp9lY7lcLhcWLFiAjIwMaDQapKenY86cObDb7T7PkZ6ejhtvvBEbNmxA7969odVq8eabb8r7Y/r06fJx07ZtWyxatAgej8fnOT766CNceuml0Ov1iI6ORrdu3bBixQoA3i6pt99+OwDg2muvleP88ccf/XqvREShiIkUEVEYSU9PR9++ffHhhx/KZd999x2MRiNGjx5dYXtJkjBy5EgsW7YM119/PZYuXYoOHTrgiSeewGOPPeaz7f3334/ly5djyJAhWLhwIVQqFW644YYKz5mTk4MrrrgC33//PaZMmYIVK1agbdu2mDhxIpYvX15t/GfPnkWnTp0we/bsGr1fl8uFoUOHIikpCUuWLMGtt94KAPj0009hsVgwadIkrFy5EkOHDsXKlStx7733VngOt9uNoUOHIj4+HkuWLEH//v3x8ssv46233vLZTzfddBPef/99jB07Fs8//zzOnDmDcePGVXi+77//HkOHDkVubi7mz5+Pxx57DNu2bUO/fv0qHQd05513wuPxYOHChejTpw+ef/55LF++HIMHD0aLFi2waNEitG3bFo8//jh++umnGu0XAHj//fdx9dVXQ6PR4P3338f777+Pa665BoD3s3zmmWdwySWXYNmyZejfvz9efPHFSo+RgwcPYsyYMRg8eDBWrFiBnj17wmKxoH///vjnP/+Je++9F6+88gr69euH2bNn+xw3mzZtwpgxYxAbG4tFixZh4cKFGDBgAH7++WcAwDXXXINp06YBAObMmSPH2alTpxq/TyKikCUREVHIW716tQ
RA2rlzp7Rq1SpJr9dLFotFkiRJuv3226Vrr71WkiRJatWqlXTDDTfIf/fll19KAKTnn3/e5/luu+02SRAE6ciRI5IkSdKePXskANIjjzzis91dd90lAZDmzZsnl02cOFFq1qyZlJ+f77Pt6NGjJYPBIMd1/PhxCYC0evVqeZuysnHjxl30PY8bN04CID311FMVHit7jfJefPFFSRAE6eTJkxWe47nnnvPZtlevXtKll14q/162nxYvXiyXuVwu6eqrr67wHnr27CklJSVJBQUFctnvv/8uiaIo3XvvvXLZvHnzJADSgw8+6POcLVu2lARBkBYuXCiXFxUVSVqttkb7pbxx48ZJOp3Op6zss7z//vt9yh9//HEJgPTDDz/IZa1atZIASOvXr/fZdsGCBZJOp5MOHTrkU/7UU09JCoVCOnXqlCRJkvToo49K0dHRksvlqjLGTz/9VAIgbdmyxa/3RkQU6tgiRUQUZu644w5YrVasW7cOJpMJ69atq7Jb37fffguFQiG3CpSZOXMmJEnCd999J28HoMJ206dP9/ldkiR89tlnGDFiBCRJQn5+vvxv6NChMBqN2LVrV5Wxp6enQ5Ikv6YTnzRpUoUyrVYr/2w2m5Gfn48rr7wSkiRh9+7dFbZ/+OGHfX6/+uqrcezYMfn3b7/9Fkql0ue1FAoFpk6d6vN3586dw549ezB+/HjExcXJ5d27d8fgwYPl/Vje/fff7/OcvXv3hiRJmDhxolweExODDh06+MRUW2UxXNjiOHPmTADAN99841PeunVrDB061Kfs008/xdVXX43Y2Fifz3jQoEFwu91yy1lMTAzMZrNPV1MioqaCs/YREYWZxMREDBo0CGvXroXFYoHb7cZtt91W6bYnT55E8+bNodfrfcrLuladPHlS/l8URWRkZPhs16FDB5/f8/LyUFxcjLfeesuna1x5ubm5tXpflVEqlWjZsmWF8lOnTuGZZ57Bv//9bxQVFfk8ZjQafX6PiIhAYmKiT1lsbKzP3508eRLNmjVDVFSUz3YXvv+y/XVhOeDdpxs2bIDZbIZOp5PL09LSfLYzGAyIiIhAQkJChfKCgoIKz+uvss+ybdu2PuUpKSmIiYmR30OZ1q1bV3iOw4cP448//qiw38qUfcaPPPIIPvnkEwwbNgwtWrTAkCFDcMcdd+D666+v8/sgIgp1TKSIiMLQXXfdhQceeADZ2dkYNmwYYmJiGuR1yyYaGDt2bKXjhwBv60ygaDQaiKJv5wm3243BgwejsLAQs2bNQseOHaHT6XD27FmMHz++wmQICoUiYPHURmWvX1VM0gUTgNSFIAg12q58614Zj8eDwYMH48knn6z0b9q3bw8ASEpKwp49e7BhwwZ89913+O6777B69Wrce++9WLNmTe2DJyIKA0ykiIjC0C233IKHHnoIv/zyCz7++OMqt2vVqhW+//57mEwmn1apAwcOyI+X/e/xeHD06FGf1paDBw/6PF/ZjH5utxuDBg0K5Fuqsb179+LQoUNYs2aNz+QSdele1qpVK2zevBmlpaU+rVIXvv+y/XVhOeDdpwkJCT6tUcFQ9lkePnzYZ1KHnJwcFBcXy++hOhkZGSgtLa3RZ6xWqzFixAiMGDECHo8HjzzyCN588008/fTTaNu2bY0TOiKicMMxUkREYSgqKgqvv/465s+fjxEjRlS53fDhw+F2u7Fq1Sqf8mXLlkEQBAwbNgwA5P9feeUVn+0unIVPoVDg1ltvxWeffYZ9+/ZVeL28vLxq4/Zn+vOqlLXmlG+9kSRJnnK7NoYPHw6Xy4XXX39dLnO73Vi5cqXPds2aNUPPnj2xZs0an6nR9+3bh40bN2L48OG1jiFQymK48LNbunQpAFQ6E+OF7rjjDmzfvh0bNmyo8FhxcbG8jtmFXRFFUZRbJMumWi9LLKuaSp6IKFyxRYqIKExV1bWuvBEjRuDaa6/F//3f/+HEiRPo0aMHNm7ciK+++grTp0+Xx0T17N
kTY8aMwWuvvQaj0Ygrr7wSmzdvxpEjRyo858KFC7Flyxb06dMHDzzwADp37ozCwkLs2rUL33//PQoLC6uMp2z683H
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x600 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIjCAYAAAAJLyrXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hUZdoG8Puc6S2dJJRAIHQBQUFUXLGgKAqr2FBZC9gwgCAq4qrYVgGpAtZVdHdFXVdXP3FtiwgWXFFAQXqTkB6STG9nzvn+mGRkSAKZZCYzSe7fdXlJ3pnMPDOZOec8b3leQVEUBURERERERNRoYrwDICIiIiIiam2YSBEREREREUWIiRQREREREVGEmEgRERERERFFiIkUERERERFRhJhIERERERERRYiJFBERERERUYSYSBEREREREUWIiRQREREREVGEmEgREVGbIwgCHnvssYh/79ChQxAEAa+//nrUY2qMr776CoIg4KuvvorL8xMRUeMxkSIioph4/fXXIQgCBEHAN998U+d2RVGQk5MDQRBw+eWXxyFCIiKipmMiRUREMaXX67F69eo67evXr8eRI0eg0+niEBUREVHzMJEiIqKYGjNmDN59911IkhTWvnr1apx++unIzs6OU2RERERNx0SKiIhi6vrrr8fRo0fxxRdfhNp8Ph/+9a9/4YYbbqj3d5xOJ2bNmoWcnBzodDr06dMHCxcuhKIoYffzer2YOXMmOnToAIvFgnHjxuHIkSP1PmZhYSEmTZqErKws6HQ6nHLKKXjttddOGr/f78euXbtQXFx80vvecsstMJvNOHz4MC6//HKYzWZ07twZK1euBABs27YNF1xwAUwmE7p161bvSN3xzjvvPAwYMAA//fQTzj77bBgMBnTv3h0vvvjiSX+XiIhih4kUERHFVG5uLs466yy89dZbobZPPvkEVqsVEyZMqHN/RVEwbtw4LFmyBJdccgkWL16MPn364P7778e9994bdt/bbrsNS5cuxcUXX4x58+ZBo9Hgsssuq/OYpaWlOPPMM/Hf//4XU6dOxbJly9CzZ09MnjwZS5cuPWH8hYWF6NevH+bMmdOo1xsIBHDppZciJycHCxYsQG5uLqZOnYrXX38dl1xyCYYOHYr58+fDYrHgpptuwsGDB0/6mFVVVRgzZgxOP/10LFiwAF26dMGUKVMalQgSEVGMKERERDGwatUqBYCyadMmZcWKFYrFYlFcLpeiKIpyzTXXKOeff76iKIrSrVs35bLLLgv93gcffKAAUJ566qmwx7v66qsVQRCUffv2KYqiKFu3blUAKHfffXfY/W644QYFgDJ37txQ2+TJk5WOHTsqFRUVYfedMGGCkpycHIrr4MGDCgBl1apVofvUtt18880nfc0333yzAkB5+umnQ21VVVWKwWBQBEFQ3n777VD7rl276sS5bt06BYCybt26UNvIkSMVAMqiRYtCbV6vVxk8eLCSmZmp+Hy+k8ZFRETRxxEpIiKKuWuvvRZutxtr1qyB3W7HmjVrGpzW95///AcqlQrTp08Pa581axYURcEnn3wSuh+AOvebMWNG2M+KouC9997D2LFjoSgKKioqQv+NHj0aVqsVmzdvbjD23NxcKIoSUUn02267LfTvlJQU9OnTByaTCddee22ovU+fPkhJScGBAwdO+nhqtRp33nln6GetVos777wTZWVl+OmnnxodFxERRY863gEQEVHb16FDB4waNQqrV6+Gy+VCIBDA1VdfXe99f/vtN3Tq1AkWiyWsvV+/fqHba/8viiLy8vLC7tenT5+wn8vLy1FdXY2XX34ZL7/8cr3PWVZW1qTXVR+9Xo8OHTqEtSUnJ6NLly4QBKFOe1VV1Ukfs1OnTjCZTGFtvXv3BhDc++rMM89sZtRERBQpJlJERNQibrjhBtx+++0oKSnBpZdeipSUlBZ5XlmWAQATJ07EzTffXO99Bg0aFLXnU6lUEbUrxx
XQICKi1oGJFBERtYgrr7wSd955J77//nu88847Dd6vW7du+O9//wu73R42KrVr167Q7bX/l2UZ+/fvDxuF2r17d9jj1Vb0CwQCGDVqVDRfUospKiqC0+kMG5Xas2cPgODUQyIianlcI0VERC3CbDbjhRdewGOPPYaxY8c2eL8xY8YgEAhgxYoVYe1LliyBIAi49NJLASD0/+eeey7sfsdX4VOpVLjqqqvw3nvvYfv27XWer7y8/IRxR1L+PFYkScJLL70U+tnn8+Gll15Chw4dcPrpp8ctLiKi9owjUkRE1GIamlp3rLFjx+L888/Hn//8Zxw6dAinnnoqPv/8c3z44YeYMWNGaE3U4MGDcf311+P555+H1WrF2WefjbVr12Lfvn11HnPevHlYt24dhg8fjttvvx39+/dHZWUlNm/ejP/+97+orKxsMJ7a8uc333xzRAUnoqlTp06YP38+Dh06hN69e+Odd97B1q1b8fLLL0Oj0cQlJiKi9o6JFBERJRRRFPF///d/ePTRR/HOO+9g1apVyM3NxbPPPotZs2aF3fe1115Dhw4d8Oabb+KDDz7ABRdcgI8//hg5OTlh98vKysIPP/yAJ554Au+//z6ef/55pKen45RTTsH8+fNb8uU1SWpqKt544w1MmzYNr7zyCrKysrBixQrcfvvt8Q6NiKjdEhSuciUiIkpY5513HioqKuqdlkhERPHDNVJEREREREQRYiJFREREREQUISZSREREREREEeIaKSIiIiIioghxRIqIiIiIiChCTKSIiIiIiIgixH2kAMiyjKKiIlgsFgiCEO9wiIiIiIgoThRFgd1uR6dOnSCKDY87MZECUFRUVGfzRiIiIiIiar8KCgrQpUuXBm9nIgXAYrEACL5ZSUlJcY6GiIiIiIjixWazIScnJ5QjNISJFBCazpeUlMREioiIiIiITrrkh8UmiIiIiIiIIsREioiIiIiIKEJMpIiIiIiIiCLERIqIiIiIiChCTKSIiIiIiIgixESKiIiIiIgoQkykiIiIiIiIIsREioiIiIiIKEJMpIiIiIiIiCLERIqIiIiIiChCTKSIiIiIiIgixESKiIiIiIgoQkykiIiIiIiIIsREioiIiIiIKEJMpIiIiIiIiCLERIqIiIiIiChCTKSIiIiIiCguFEXBvn374h1GkzCRIiIiIiKiFrd9+3acf/75GDp0KMrKyuIdTsSYSBERERERUYuxWq2YMWMGBg8ejPXr18NqtWLOnDnxDiti6ngHQEREREREbZ8sy/j73/+OBx54IGwEKi8vD+PHj49jZE3DRIqIiIiIiGJq69atyM/Px3fffRdqMxgMeOihh3DfffdBr9fHMbqmYSJFREREREQxYbVa8dBDD+HFF1+ELMuh9vHjx2Px4sXo1q1bHKNrHiZSREREREQUE4qi4F//+lcoierduzeWL1+Oiy++OM6RNR+LTRARERERUUykpKRgwYIFMJlMmDdvHrZt29YmkiiAiRQREREREUXB0aNHMW3aNBQWFoa1/+lPf8LevXsxe/ZsaLXaOEUXfZzaR0RERERETRYIBPDKK6/gz3/+MyorK3H06FGsXr06dLsoiujYsWMcI4wNjkgREREREVGTbNy4EWeccQamTJmCyspKAMCaNWtQXFwc58hij4kUERERERFFpKysDJMmTcLZZ5+NzZs3h9onTpyI3bt3t8kRqONxah8RERERETWKJEl44YUX8Mgjj8BqtYbaBw4ciJUrV+IPf/hDHKNrWUykiIiIiIioUa644gp8/PHHoZ+Tk5Px5JNPYsqUKVCr21dqwal9RERERETUKDfddFPo37fccgt2796NadOmtbskCuCIFBERERER1cPv98NutyMtLS3Uds011+Cbb77B9ddfj7POOiuO0cUfR6SIiIiIiCjMV199hSFDhmDy5Mlh7YIg4Lnnnmv3SRTARIqIiIiIiGocOXIE119/Pc4//3z8+uuv+OCDD/Dpp5/GO6
yExESKiIiIiKid8/l8WLBgAfr27Yu333471D5s2DBkZmbGMbLExTVSRERERETt2BdffIFp06Zh9+7dobb09HTMmzc
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x600 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
    "# Plot predicted vs. actual values for every model\n",
|
|||
|
"for model_name, model_data in class_models.items():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" y_pred = model_data[\"preds\"]\n",
|
|||
|
" plt.figure(figsize=(10, 6))\n",
|
|||
|
" plt.scatter(y_test, y_pred, alpha=0.5)\n",
|
|||
|
" plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)\n",
|
|||
|
    "    plt.xlabel('Actual glucose level')\n",
    "    plt.ylabel('Predicted glucose level')\n",
|
|||
|
" plt.title(f\"Model: {model_name}\")\n",
|
|||
|
" plt.show()"
|
|||
|
]
|
|||
|
},
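  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scatter plots above compare the predictions with the diagonal y = x. A residual plot shows the residuals (actual minus predicted) against the prediction. It displays the same deviations around a horizontal zero line, which can make systematic over- or under-prediction easier to spot. This is a minimal sketch on synthetic data. The helper `plot_predictions_and_residuals` is hypothetical; in this notebook it could be called, for example, with `y_test.values.ravel()` and `model_data['preds']`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: predicted-vs-actual plot next to a residual plot\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def plot_predictions_and_residuals(y_true, y_pred, title=''):\n",
    "    y_true = np.asarray(y_true, dtype=float)\n",
    "    y_pred = np.asarray(y_pred, dtype=float)\n",
    "    residuals = y_true - y_pred\n",
    "\n",
    "    fig, (ax_scatter, ax_resid) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "    # Left: predictions against actual values with the ideal line y = x\n",
    "    lims = [min(y_true.min(), y_pred.min()), max(y_true.max(), y_pred.max())]\n",
    "    ax_scatter.scatter(y_true, y_pred, alpha=0.5)\n",
    "    ax_scatter.plot(lims, lims, 'k--', lw=2)\n",
    "    ax_scatter.set_xlabel('Actual value')\n",
    "    ax_scatter.set_ylabel('Predicted value')\n",
    "\n",
    "    # Right: residuals against predictions with a zero reference line\n",
    "    ax_resid.scatter(y_pred, residuals, alpha=0.5)\n",
    "    ax_resid.axhline(0, color='k', linestyle='--', lw=2)\n",
    "    ax_resid.set_xlabel('Predicted value')\n",
    "    ax_resid.set_ylabel('Residual (actual - predicted)')\n",
    "\n",
    "    fig.suptitle(title)\n",
    "    fig.tight_layout()\n",
    "    return fig, residuals\n",
    "\n",
    "\n",
    "# Synthetic example data\n",
    "rng = np.random.default_rng(9)\n",
    "demo_true = rng.normal(100, 40, size=200)\n",
    "demo_pred = demo_true + rng.normal(0, 30, size=200)\n",
    "demo_fig, demo_residuals = plot_predictions_and_residuals(demo_true, demo_pred, title='Synthetic example')\n",
    "plt.show()"
   ]
  },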
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
    "The plots show that, overall, the models do not achieve high quality. Their predictions are widely scattered around the ideal line y = x, which means they deviate substantially from the actual values.\n",
    "\n",
    "Nevertheless, each model outperforms the baseline on all metrics, even if not by a large margin. The most notable improvement is in R2, which goes from a negative value to a positive one (0.10 to 0.14 on the test set). This means the models explain at least part of the variance in the data.\n",
    "\n",
    "In addition, all models have moderate variance and are not strongly prone to overfitting: the difference between training and test RMSE is small for each of them (at most about 1.7).\n",
    "\n",
    "Conclusions:\n",
    "- Best model: MLP. It has the lowest test RMSE (40.92) and the highest R2 (0.139), which indicates the best accuracy and the largest share of explained variance of the target variable.\n",
    "\n",
    "- Random Forest: close to MLP in performance, with a slightly higher test RMSE (41.09) and an R2 of 0.132. Its training and test RMSE are also close (40.32 vs. 41.09).\n",
    "\n",
    "- KNN: the weakest model, with the largest test errors (RMSE 41.83) and the lowest R2 (0.101). It would need further improvement, or a different model should be used for this task."
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "aimenv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.12.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|