5161 lines
594 KiB
Plaintext
5161 lines
594 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Лабораторная 4\n",
|
|||
|
"Датасет: Набор данных для анализа и прогнозирования сердечного приступа"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 54,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Index(['State', 'Sex', 'GeneralHealth', 'PhysicalHealthDays',\n",
|
|||
|
" 'MentalHealthDays', 'LastCheckupTime', 'PhysicalActivities',\n",
|
|||
|
" 'SleepHours', 'RemovedTeeth', 'HadHeartAttack', 'HadAngina',\n",
|
|||
|
" 'HadStroke', 'HadAsthma', 'HadSkinCancer', 'HadCOPD',\n",
|
|||
|
" 'HadDepressiveDisorder', 'HadKidneyDisease', 'HadArthritis',\n",
|
|||
|
" 'HadDiabetes', 'DeafOrHardOfHearing', 'BlindOrVisionDifficulty',\n",
|
|||
|
" 'DifficultyConcentrating', 'DifficultyWalking',\n",
|
|||
|
" 'DifficultyDressingBathing', 'DifficultyErrands', 'SmokerStatus',\n",
|
|||
|
" 'ECigaretteUsage', 'ChestScan', 'RaceEthnicityCategory', 'AgeCategory',\n",
|
|||
|
" 'HeightInMeters', 'WeightInKilograms', 'BMI', 'AlcoholDrinkers',\n",
|
|||
|
" 'HIVTesting', 'FluVaxLast12', 'PneumoVaxEver', 'TetanusLast10Tdap',\n",
|
|||
|
" 'HighRiskLastYear', 'CovidPos'],\n",
|
|||
|
" dtype='object')\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"df = pd.read_csv(\"csv\\\\heart_2022_no_nans.csv\")\n",
|
|||
|
"print(df.columns)\n",
|
|||
|
"map_heart_disease_to_int = {'No': 0, 'Yes': 1}\n",
|
|||
|
"\n",
|
|||
|
"TARGET_COLUMN_NAME_CLASSIFICATION = 'HadHeartAttack'\n",
|
|||
|
"\n",
|
|||
|
"df[TARGET_COLUMN_NAME_CLASSIFICATION] = df[TARGET_COLUMN_NAME_CLASSIFICATION].map(map_heart_disease_to_int).astype('int32')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Классификация"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Бизнес цель 1: \n",
|
|||
|
"Предсказание сердечного приступа (HadHeartAttack) на основе других факторов."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формируем выборки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 55,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>State</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>GeneralHealth</th>\n",
|
|||
|
" <th>PhysicalHealthDays</th>\n",
|
|||
|
" <th>MentalHealthDays</th>\n",
|
|||
|
" <th>LastCheckupTime</th>\n",
|
|||
|
" <th>PhysicalActivities</th>\n",
|
|||
|
" <th>SleepHours</th>\n",
|
|||
|
" <th>RemovedTeeth</th>\n",
|
|||
|
" <th>HadHeartAttack</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>HeightInMeters</th>\n",
|
|||
|
" <th>WeightInKilograms</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>AlcoholDrinkers</th>\n",
|
|||
|
" <th>HIVTesting</th>\n",
|
|||
|
" <th>FluVaxLast12</th>\n",
|
|||
|
" <th>PneumoVaxEver</th>\n",
|
|||
|
" <th>TetanusLast10Tdap</th>\n",
|
|||
|
" <th>HighRiskLastYear</th>\n",
|
|||
|
" <th>CovidPos</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>6432</th>\n",
|
|||
|
" <td>Arizona</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>Within past 5 years (2 years but less than 5 y...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.88</td>\n",
|
|||
|
" <td>77.11</td>\n",
|
|||
|
" <td>21.83</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>61767</th>\n",
|
|||
|
" <td>Indiana</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.73</td>\n",
|
|||
|
" <td>77.11</td>\n",
|
|||
|
" <td>25.85</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>102005</th>\n",
|
|||
|
" <td>Michigan</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.85</td>\n",
|
|||
|
" <td>83.46</td>\n",
|
|||
|
" <td>24.28</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>183791</th>\n",
|
|||
|
" <td>South Dakota</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.75</td>\n",
|
|||
|
" <td>81.65</td>\n",
|
|||
|
" <td>26.58</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>230656</th>\n",
|
|||
|
" <td>West Virginia</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" <td>6 or more, but not all</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.55</td>\n",
|
|||
|
" <td>68.04</td>\n",
|
|||
|
" <td>28.34</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>93877</th>\n",
|
|||
|
" <td>Maryland</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>12.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" <td>6 or more, but not all</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.65</td>\n",
|
|||
|
" <td>113.40</td>\n",
|
|||
|
" <td>41.60</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>117856</th>\n",
|
|||
|
" <td>Missouri</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.80</td>\n",
|
|||
|
" <td>117.93</td>\n",
|
|||
|
" <td>36.26</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>41922</th>\n",
|
|||
|
" <td>Georgia</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.78</td>\n",
|
|||
|
" <td>113.40</td>\n",
|
|||
|
" <td>35.87</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>98221</th>\n",
|
|||
|
" <td>Massachusetts</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>20.0</td>\n",
|
|||
|
" <td>Within past 2 years (1 year but less than 2 ye...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.70</td>\n",
|
|||
|
" <td>90.72</td>\n",
|
|||
|
" <td>31.32</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>151717</th>\n",
|
|||
|
" <td>New York</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.73</td>\n",
|
|||
|
" <td>68.95</td>\n",
|
|||
|
" <td>23.11</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>196817 rows × 40 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" State Sex GeneralHealth PhysicalHealthDays \\\n",
|
|||
|
"6432 Arizona Male Very good 0.0 \n",
|
|||
|
"61767 Indiana Female Very good 0.0 \n",
|
|||
|
"102005 Michigan Male Very good 0.0 \n",
|
|||
|
"183791 South Dakota Female Good 10.0 \n",
|
|||
|
"230656 West Virginia Female Good 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"93877 Maryland Female Very good 0.0 \n",
|
|||
|
"117856 Missouri Male Good 0.0 \n",
|
|||
|
"41922 Georgia Male Very good 0.0 \n",
|
|||
|
"98221 Massachusetts Female Good 5.0 \n",
|
|||
|
"151717 New York Male Very good 2.0 \n",
|
|||
|
"\n",
|
|||
|
" MentalHealthDays LastCheckupTime \\\n",
|
|||
|
"6432 5.0 Within past 5 years (2 years but less than 5 y... \n",
|
|||
|
"61767 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"102005 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"183791 5.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"230656 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"... ... ... \n",
|
|||
|
"93877 12.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"117856 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"41922 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"98221 20.0 Within past 2 years (1 year but less than 2 ye... \n",
|
|||
|
"151717 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"\n",
|
|||
|
" PhysicalActivities SleepHours RemovedTeeth HadHeartAttack \\\n",
|
|||
|
"6432 Yes 8.0 None of them 0 \n",
|
|||
|
"61767 Yes 6.0 None of them 0 \n",
|
|||
|
"102005 Yes 7.0 None of them 0 \n",
|
|||
|
"183791 Yes 7.0 None of them 0 \n",
|
|||
|
"230656 No 8.0 6 or more, but not all 0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"93877 No 6.0 6 or more, but not all 0 \n",
|
|||
|
"117856 Yes 8.0 1 to 5 0 \n",
|
|||
|
"41922 Yes 7.0 None of them 0 \n",
|
|||
|
"98221 No 5.0 None of them 0 \n",
|
|||
|
"151717 Yes 7.0 None of them 0 \n",
|
|||
|
"\n",
|
|||
|
" ... HeightInMeters WeightInKilograms BMI AlcoholDrinkers \\\n",
|
|||
|
"6432 ... 1.88 77.11 21.83 Yes \n",
|
|||
|
"61767 ... 1.73 77.11 25.85 Yes \n",
|
|||
|
"102005 ... 1.85 83.46 24.28 Yes \n",
|
|||
|
"183791 ... 1.75 81.65 26.58 No \n",
|
|||
|
"230656 ... 1.55 68.04 28.34 No \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"93877 ... 1.65 113.40 41.60 No \n",
|
|||
|
"117856 ... 1.80 117.93 36.26 No \n",
|
|||
|
"41922 ... 1.78 113.40 35.87 Yes \n",
|
|||
|
"98221 ... 1.70 90.72 31.32 Yes \n",
|
|||
|
"151717 ... 1.73 68.95 23.11 Yes \n",
|
|||
|
"\n",
|
|||
|
" HIVTesting FluVaxLast12 PneumoVaxEver \\\n",
|
|||
|
"6432 Yes Yes No \n",
|
|||
|
"61767 Yes No No \n",
|
|||
|
"102005 No No No \n",
|
|||
|
"183791 Yes No No \n",
|
|||
|
"230656 No No No \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"93877 No No No \n",
|
|||
|
"117856 No No No \n",
|
|||
|
"41922 No Yes No \n",
|
|||
|
"98221 Yes No Yes \n",
|
|||
|
"151717 Yes Yes No \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap HighRiskLastYear \\\n",
|
|||
|
"6432 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"61767 Yes, received Tdap No \n",
|
|||
|
"102005 Yes, received Tdap No \n",
|
|||
|
"183791 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"230656 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"... ... ... \n",
|
|||
|
"93877 Yes, received Tdap No \n",
|
|||
|
"117856 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"41922 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"98221 Yes, received Tdap No \n",
|
|||
|
"151717 Yes, received Tdap No \n",
|
|||
|
"\n",
|
|||
|
" CovidPos \n",
|
|||
|
"6432 Yes \n",
|
|||
|
"61767 No \n",
|
|||
|
"102005 Yes \n",
|
|||
|
"183791 No \n",
|
|||
|
"230656 No \n",
|
|||
|
"... ... \n",
|
|||
|
"93877 Yes \n",
|
|||
|
"117856 Yes \n",
|
|||
|
"41922 No \n",
|
|||
|
"98221 No \n",
|
|||
|
"151717 Yes \n",
|
|||
|
"\n",
|
|||
|
"[196817 rows x 40 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>HadHeartAttack</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>6432</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>61767</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>102005</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>183791</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>230656</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>93877</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>117856</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>41922</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>98221</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>151717</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>196817 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" HadHeartAttack\n",
|
|||
|
"6432 0\n",
|
|||
|
"61767 0\n",
|
|||
|
"102005 0\n",
|
|||
|
"183791 0\n",
|
|||
|
"230656 0\n",
|
|||
|
"... ...\n",
|
|||
|
"93877 0\n",
|
|||
|
"117856 0\n",
|
|||
|
"41922 0\n",
|
|||
|
"98221 0\n",
|
|||
|
"151717 0\n",
|
|||
|
"\n",
|
|||
|
"[196817 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>State</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>GeneralHealth</th>\n",
|
|||
|
" <th>PhysicalHealthDays</th>\n",
|
|||
|
" <th>MentalHealthDays</th>\n",
|
|||
|
" <th>LastCheckupTime</th>\n",
|
|||
|
" <th>PhysicalActivities</th>\n",
|
|||
|
" <th>SleepHours</th>\n",
|
|||
|
" <th>RemovedTeeth</th>\n",
|
|||
|
" <th>HadHeartAttack</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>HeightInMeters</th>\n",
|
|||
|
" <th>WeightInKilograms</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>AlcoholDrinkers</th>\n",
|
|||
|
" <th>HIVTesting</th>\n",
|
|||
|
" <th>FluVaxLast12</th>\n",
|
|||
|
" <th>PneumoVaxEver</th>\n",
|
|||
|
" <th>TetanusLast10Tdap</th>\n",
|
|||
|
" <th>HighRiskLastYear</th>\n",
|
|||
|
" <th>CovidPos</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>108080</th>\n",
|
|||
|
" <td>Minnesota</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.68</td>\n",
|
|||
|
" <td>81.65</td>\n",
|
|||
|
" <td>29.05</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received tetanus shot, but not Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>109629</th>\n",
|
|||
|
" <td>Minnesota</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>15.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.68</td>\n",
|
|||
|
" <td>99.79</td>\n",
|
|||
|
" <td>35.51</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>24640</th>\n",
|
|||
|
" <td>Connecticut</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>15.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>6 or more, but not all</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.70</td>\n",
|
|||
|
" <td>72.57</td>\n",
|
|||
|
" <td>25.06</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>12715</th>\n",
|
|||
|
" <td>Arkansas</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" <td>30.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.63</td>\n",
|
|||
|
" <td>86.18</td>\n",
|
|||
|
" <td>32.61</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>162549</th>\n",
|
|||
|
" <td>Ohio</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Excellent</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>4.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.60</td>\n",
|
|||
|
" <td>81.19</td>\n",
|
|||
|
" <td>31.71</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Tested positive using home test without a heal...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>187130</th>\n",
|
|||
|
" <td>South Dakota</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Poor</td>\n",
|
|||
|
" <td>30.0</td>\n",
|
|||
|
" <td>30.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>4.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.83</td>\n",
|
|||
|
" <td>97.98</td>\n",
|
|||
|
" <td>29.29</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>38512</th>\n",
|
|||
|
" <td>Florida</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Excellent</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past 5 years (2 years but less than 5 y...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.83</td>\n",
|
|||
|
" <td>104.33</td>\n",
|
|||
|
" <td>31.19</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>125776</th>\n",
|
|||
|
" <td>Nebraska</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Fair</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.73</td>\n",
|
|||
|
" <td>92.99</td>\n",
|
|||
|
" <td>31.17</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>33614</th>\n",
|
|||
|
" <td>Florida</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.60</td>\n",
|
|||
|
" <td>65.77</td>\n",
|
|||
|
" <td>25.69</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>223067</th>\n",
|
|||
|
" <td>Washington</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Excellent</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>Within past 2 years (1 year but less than 2 ye...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.75</td>\n",
|
|||
|
" <td>70.00</td>\n",
|
|||
|
" <td>22.86</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>49205 rows × 40 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" State Sex GeneralHealth PhysicalHealthDays \\\n",
|
|||
|
"108080 Minnesota Female Very good 0.0 \n",
|
|||
|
"109629 Minnesota Female Very good 1.0 \n",
|
|||
|
"24640 Connecticut Male Good 15.0 \n",
|
|||
|
"12715 Arkansas Female Good 8.0 \n",
|
|||
|
"162549 Ohio Female Excellent 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"187130 South Dakota Male Poor 30.0 \n",
|
|||
|
"38512 Florida Male Excellent 0.0 \n",
|
|||
|
"125776 Nebraska Male Fair 1.0 \n",
|
|||
|
"33614 Florida Female Good 0.0 \n",
|
|||
|
"223067 Washington Male Excellent 0.0 \n",
|
|||
|
"\n",
|
|||
|
" MentalHealthDays LastCheckupTime \\\n",
|
|||
|
"108080 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"109629 15.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"24640 5.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"12715 30.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"162549 7.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"... ... ... \n",
|
|||
|
"187130 30.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"38512 0.0 Within past 5 years (2 years but less than 5 y... \n",
|
|||
|
"125776 2.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"33614 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"223067 2.0 Within past 2 years (1 year but less than 2 ye... \n",
|
|||
|
"\n",
|
|||
|
" PhysicalActivities SleepHours RemovedTeeth HadHeartAttack \\\n",
|
|||
|
"108080 Yes 7.0 None of them 0 \n",
|
|||
|
"109629 Yes 6.0 None of them 0 \n",
|
|||
|
"24640 Yes 7.0 6 or more, but not all 0 \n",
|
|||
|
"12715 Yes 7.0 1 to 5 0 \n",
|
|||
|
"162549 Yes 4.0 None of them 0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"187130 No 4.0 None of them 0 \n",
|
|||
|
"38512 Yes 8.0 None of them 0 \n",
|
|||
|
"125776 No 6.0 1 to 5 0 \n",
|
|||
|
"33614 Yes 7.0 None of them 0 \n",
|
|||
|
"223067 Yes 7.0 1 to 5 0 \n",
|
|||
|
"\n",
|
|||
|
" ... HeightInMeters WeightInKilograms BMI AlcoholDrinkers \\\n",
|
|||
|
"108080 ... 1.68 81.65 29.05 Yes \n",
|
|||
|
"109629 ... 1.68 99.79 35.51 Yes \n",
|
|||
|
"24640 ... 1.70 72.57 25.06 Yes \n",
|
|||
|
"12715 ... 1.63 86.18 32.61 Yes \n",
|
|||
|
"162549 ... 1.60 81.19 31.71 Yes \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"187130 ... 1.83 97.98 29.29 Yes \n",
|
|||
|
"38512 ... 1.83 104.33 31.19 Yes \n",
|
|||
|
"125776 ... 1.73 92.99 31.17 No \n",
|
|||
|
"33614 ... 1.60 65.77 25.69 Yes \n",
|
|||
|
"223067 ... 1.75 70.00 22.86 Yes \n",
|
|||
|
"\n",
|
|||
|
" HIVTesting FluVaxLast12 PneumoVaxEver \\\n",
|
|||
|
"108080 No No No \n",
|
|||
|
"109629 No Yes No \n",
|
|||
|
"24640 Yes Yes Yes \n",
|
|||
|
"12715 Yes Yes No \n",
|
|||
|
"162549 Yes Yes No \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"187130 No No No \n",
|
|||
|
"38512 No No No \n",
|
|||
|
"125776 Yes No Yes \n",
|
|||
|
"33614 No No Yes \n",
|
|||
|
"223067 Yes Yes No \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap HighRiskLastYear \\\n",
|
|||
|
"108080 Yes, received tetanus shot, but not Tdap No \n",
|
|||
|
"109629 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"24640 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"12715 Yes, received Tdap No \n",
|
|||
|
"162549 Yes, received Tdap Yes \n",
|
|||
|
"... ... ... \n",
|
|||
|
"187130 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"38512 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"125776 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"33614 Yes, received Tdap No \n",
|
|||
|
"223067 Yes, received Tdap No \n",
|
|||
|
"\n",
|
|||
|
" CovidPos \n",
|
|||
|
"108080 Yes \n",
|
|||
|
"109629 No \n",
|
|||
|
"24640 No \n",
|
|||
|
"12715 No \n",
|
|||
|
"162549 Tested positive using home test without a heal... \n",
|
|||
|
"... ... \n",
|
|||
|
"187130 No \n",
|
|||
|
"38512 No \n",
|
|||
|
"125776 Yes \n",
|
|||
|
"33614 No \n",
|
|||
|
"223067 No \n",
|
|||
|
"\n",
|
|||
|
"[49205 rows x 40 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>HadHeartAttack</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>108080</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>109629</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>24640</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>12715</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>162549</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>187130</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>38512</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>125776</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>33614</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>223067</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>49205 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" HadHeartAttack\n",
|
|||
|
"108080 0\n",
|
|||
|
"109629 0\n",
|
|||
|
"24640 0\n",
|
|||
|
"12715 0\n",
|
|||
|
"162549 0\n",
|
|||
|
"... ...\n",
|
|||
|
"187130 0\n",
|
|||
|
"38512 0\n",
|
|||
|
"125776 0\n",
|
|||
|
"33614 0\n",
|
|||
|
"223067 0\n",
|
|||
|
"\n",
|
|||
|
"[49205 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from typing import Tuple\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from pandas import DataFrame\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"def split_stratified_into_train_val_test(\n",
|
|||
|
" df_input,\n",
|
|||
|
" stratify_colname=\"y\",\n",
|
|||
|
" frac_train=0.6,\n",
|
|||
|
" frac_val=0.15,\n",
|
|||
|
" frac_test=0.25,\n",
|
|||
|
" random_state=None,\n",
|
|||
|
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
|||
|
" \n",
|
|||
|
" if frac_train + frac_val + frac_test != 1.0:\n",
|
|||
|
" raise ValueError(\n",
|
|||
|
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
|
|||
|
" % (frac_train, frac_val, frac_test)\n",
|
|||
|
" )\n",
|
|||
|
" if stratify_colname not in df_input.columns:\n",
|
|||
|
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
|
|||
|
" X = df_input # Contains all columns.\n",
|
|||
|
" y = df_input[\n",
|
|||
|
" [stratify_colname]\n",
|
|||
|
" ] # Dataframe of just the column on which to stratify.\n",
|
|||
|
" # Split original dataframe into train and temp dataframes.\n",
|
|||
|
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
|
|||
|
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" if frac_val <= 0:\n",
|
|||
|
" assert len(df_input) == len(df_train) + len(df_temp)\n",
|
|||
|
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
|
|||
|
" # Split the temp dataframe into val and test dataframes.\n",
|
|||
|
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
|
|||
|
" df_val, df_test, y_val, y_test = train_test_split(\n",
|
|||
|
" df_temp,\n",
|
|||
|
" y_temp,\n",
|
|||
|
" stratify=y_temp,\n",
|
|||
|
" test_size=relative_frac_test,\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" )\n",
|
|||
|
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
|
|||
|
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
|
|||
|
"\n",
|
|||
|
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
|
|||
|
" df, stratify_colname=TARGET_COLUMN_NAME_CLASSIFICATION, frac_train=0.80, frac_val=0, frac_test=0.20, random_state=9\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_train\", X_train)\n",
|
|||
|
"display(\"y_train\", y_train)\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_test\", X_test)\n",
|
|||
|
"display(\"y_test\", y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 56,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Пропущенные значения по столбцам:\n",
|
|||
|
"State 0\n",
|
|||
|
"Sex 0\n",
|
|||
|
"GeneralHealth 0\n",
|
|||
|
"PhysicalHealthDays 0\n",
|
|||
|
"MentalHealthDays 0\n",
|
|||
|
"LastCheckupTime 0\n",
|
|||
|
"PhysicalActivities 0\n",
|
|||
|
"SleepHours 0\n",
|
|||
|
"RemovedTeeth 0\n",
|
|||
|
"HadHeartAttack 0\n",
|
|||
|
"HadAngina 0\n",
|
|||
|
"HadStroke 0\n",
|
|||
|
"HadAsthma 0\n",
|
|||
|
"HadSkinCancer 0\n",
|
|||
|
"HadCOPD 0\n",
|
|||
|
"HadDepressiveDisorder 0\n",
|
|||
|
"HadKidneyDisease 0\n",
|
|||
|
"HadArthritis 0\n",
|
|||
|
"HadDiabetes 0\n",
|
|||
|
"DeafOrHardOfHearing 0\n",
|
|||
|
"BlindOrVisionDifficulty 0\n",
|
|||
|
"DifficultyConcentrating 0\n",
|
|||
|
"DifficultyWalking 0\n",
|
|||
|
"DifficultyDressingBathing 0\n",
|
|||
|
"DifficultyErrands 0\n",
|
|||
|
"SmokerStatus 0\n",
|
|||
|
"ECigaretteUsage 0\n",
|
|||
|
"ChestScan 0\n",
|
|||
|
"RaceEthnicityCategory 0\n",
|
|||
|
"AgeCategory 0\n",
|
|||
|
"HeightInMeters 0\n",
|
|||
|
"WeightInKilograms 0\n",
|
|||
|
"BMI 0\n",
|
|||
|
"AlcoholDrinkers 0\n",
|
|||
|
"HIVTesting 0\n",
|
|||
|
"FluVaxLast12 0\n",
|
|||
|
"PneumoVaxEver 0\n",
|
|||
|
"TetanusLast10Tdap 0\n",
|
|||
|
"HighRiskLastYear 0\n",
|
|||
|
"CovidPos 0\n",
|
|||
|
"dtype: int64\n",
|
|||
|
"\n",
|
|||
|
"Статистический обзор данных:\n",
|
|||
|
" PhysicalHealthDays MentalHealthDays SleepHours HadHeartAttack \\\n",
|
|||
|
"count 246022.000000 246022.000000 246022.000000 246022.000000 \n",
|
|||
|
"mean 4.119026 4.167140 7.021331 0.054609 \n",
|
|||
|
"std 8.405844 8.102687 1.440681 0.227216 \n",
|
|||
|
"min 0.000000 0.000000 1.000000 0.000000 \n",
|
|||
|
"25% 0.000000 0.000000 6.000000 0.000000 \n",
|
|||
|
"50% 0.000000 0.000000 7.000000 0.000000 \n",
|
|||
|
"75% 3.000000 4.000000 8.000000 0.000000 \n",
|
|||
|
"max 30.000000 30.000000 24.000000 1.000000 \n",
|
|||
|
"\n",
|
|||
|
" HeightInMeters WeightInKilograms BMI \n",
|
|||
|
"count 246022.000000 246022.000000 246022.000000 \n",
|
|||
|
"mean 1.705150 83.615179 28.668136 \n",
|
|||
|
"std 0.106654 21.323156 6.513973 \n",
|
|||
|
"min 0.910000 28.120000 12.020000 \n",
|
|||
|
"25% 1.630000 68.040000 24.270000 \n",
|
|||
|
"50% 1.700000 81.650000 27.460000 \n",
|
|||
|
"75% 1.780000 95.250000 31.890000 \n",
|
|||
|
"max 2.410000 292.570000 97.650000 \n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"null_values = df.isnull().sum()\n",
|
|||
|
"print(\"Пропущенные значения по столбцам:\")\n",
|
|||
|
"print(null_values)\n",
|
|||
|
"\n",
|
|||
|
"stat_summary = df.describe()\n",
|
|||
|
"print(\"\\nСтатистический обзор данных:\")\n",
|
|||
|
"print(stat_summary)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формируем конвеер для классификации данных и проверка конвеера"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 57,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>PhysicalHealthDays</th>\n",
|
|||
|
" <th>MentalHealthDays</th>\n",
|
|||
|
" <th>SleepHours</th>\n",
|
|||
|
" <th>HadHeartAttack</th>\n",
|
|||
|
" <th>HeightInMeters</th>\n",
|
|||
|
" <th>WeightInKilograms</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>State_Alaska</th>\n",
|
|||
|
" <th>State_Arizona</th>\n",
|
|||
|
" <th>State_Arkansas</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>AlcoholDrinkers_Yes</th>\n",
|
|||
|
" <th>HIVTesting_Yes</th>\n",
|
|||
|
" <th>FluVaxLast12_Yes</th>\n",
|
|||
|
" <th>PneumoVaxEver_Yes</th>\n",
|
|||
|
" <th>TetanusLast10Tdap_Yes, received Tdap</th>\n",
|
|||
|
" <th>TetanusLast10Tdap_Yes, received tetanus shot but not sure what type</th>\n",
|
|||
|
" <th>TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap</th>\n",
|
|||
|
" <th>HighRiskLastYear_Yes</th>\n",
|
|||
|
" <th>CovidPos_Tested positive using home test without a health professional</th>\n",
|
|||
|
" <th>CovidPos_Yes</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>6432</th>\n",
|
|||
|
" <td>-0.490179</td>\n",
|
|||
|
" <td>0.103124</td>\n",
|
|||
|
" <td>0.677965</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>1.639362</td>\n",
|
|||
|
" <td>-0.304540</td>\n",
|
|||
|
" <td>-1.051314</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>61767</th>\n",
|
|||
|
" <td>-0.490179</td>\n",
|
|||
|
" <td>-0.513985</td>\n",
|
|||
|
" <td>-0.708460</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>0.233664</td>\n",
|
|||
|
" <td>-0.304540</td>\n",
|
|||
|
" <td>-0.432966</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>102005</th>\n",
|
|||
|
" <td>-0.490179</td>\n",
|
|||
|
" <td>-0.513985</td>\n",
|
|||
|
" <td>-0.015247</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>1.358222</td>\n",
|
|||
|
" <td>-0.006656</td>\n",
|
|||
|
" <td>-0.674460</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>183791</th>\n",
|
|||
|
" <td>0.699048</td>\n",
|
|||
|
" <td>0.103124</td>\n",
|
|||
|
" <td>-0.015247</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>0.421091</td>\n",
|
|||
|
" <td>-0.091564</td>\n",
|
|||
|
" <td>-0.320678</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>230656</th>\n",
|
|||
|
" <td>-0.490179</td>\n",
|
|||
|
" <td>-0.513985</td>\n",
|
|||
|
" <td>0.677965</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>-1.453173</td>\n",
|
|||
|
" <td>-0.730021</td>\n",
|
|||
|
" <td>-0.049959</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>93877</th>\n",
|
|||
|
" <td>-0.490179</td>\n",
|
|||
|
" <td>0.967076</td>\n",
|
|||
|
" <td>-0.708460</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>-0.516041</td>\n",
|
|||
|
" <td>1.397856</td>\n",
|
|||
|
" <td>1.989666</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>117856</th>\n",
|
|||
|
" <td>-0.490179</td>\n",
|
|||
|
" <td>-0.513985</td>\n",
|
|||
|
" <td>0.677965</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>0.889656</td>\n",
|
|||
|
" <td>1.610362</td>\n",
|
|||
|
" <td>1.168279</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>41922</th>\n",
|
|||
|
" <td>-0.490179</td>\n",
|
|||
|
" <td>-0.513985</td>\n",
|
|||
|
" <td>-0.015247</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>0.702230</td>\n",
|
|||
|
" <td>1.397856</td>\n",
|
|||
|
" <td>1.108290</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>98221</th>\n",
|
|||
|
" <td>0.104435</td>\n",
|
|||
|
" <td>1.954450</td>\n",
|
|||
|
" <td>-1.401672</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>-0.047475</td>\n",
|
|||
|
" <td>0.333917</td>\n",
|
|||
|
" <td>0.408418</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>151717</th>\n",
|
|||
|
" <td>-0.252334</td>\n",
|
|||
|
" <td>-0.513985</td>\n",
|
|||
|
" <td>-0.015247</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>0.233664</td>\n",
|
|||
|
" <td>-0.687332</td>\n",
|
|||
|
" <td>-0.854427</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>196817 rows × 109 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" PhysicalHealthDays MentalHealthDays SleepHours HadHeartAttack \\\n",
|
|||
|
"6432 -0.490179 0.103124 0.677965 -0.24034 \n",
|
|||
|
"61767 -0.490179 -0.513985 -0.708460 -0.24034 \n",
|
|||
|
"102005 -0.490179 -0.513985 -0.015247 -0.24034 \n",
|
|||
|
"183791 0.699048 0.103124 -0.015247 -0.24034 \n",
|
|||
|
"230656 -0.490179 -0.513985 0.677965 -0.24034 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"93877 -0.490179 0.967076 -0.708460 -0.24034 \n",
|
|||
|
"117856 -0.490179 -0.513985 0.677965 -0.24034 \n",
|
|||
|
"41922 -0.490179 -0.513985 -0.015247 -0.24034 \n",
|
|||
|
"98221 0.104435 1.954450 -1.401672 -0.24034 \n",
|
|||
|
"151717 -0.252334 -0.513985 -0.015247 -0.24034 \n",
|
|||
|
"\n",
|
|||
|
" HeightInMeters WeightInKilograms BMI State_Alaska \\\n",
|
|||
|
"6432 1.639362 -0.304540 -1.051314 0.0 \n",
|
|||
|
"61767 0.233664 -0.304540 -0.432966 0.0 \n",
|
|||
|
"102005 1.358222 -0.006656 -0.674460 0.0 \n",
|
|||
|
"183791 0.421091 -0.091564 -0.320678 0.0 \n",
|
|||
|
"230656 -1.453173 -0.730021 -0.049959 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"93877 -0.516041 1.397856 1.989666 0.0 \n",
|
|||
|
"117856 0.889656 1.610362 1.168279 0.0 \n",
|
|||
|
"41922 0.702230 1.397856 1.108290 0.0 \n",
|
|||
|
"98221 -0.047475 0.333917 0.408418 0.0 \n",
|
|||
|
"151717 0.233664 -0.687332 -0.854427 0.0 \n",
|
|||
|
"\n",
|
|||
|
" State_Arizona State_Arkansas ... AlcoholDrinkers_Yes \\\n",
|
|||
|
"6432 1.0 0.0 ... 1.0 \n",
|
|||
|
"61767 0.0 0.0 ... 1.0 \n",
|
|||
|
"102005 0.0 0.0 ... 1.0 \n",
|
|||
|
"183791 0.0 0.0 ... 0.0 \n",
|
|||
|
"230656 0.0 0.0 ... 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"93877 0.0 0.0 ... 0.0 \n",
|
|||
|
"117856 0.0 0.0 ... 0.0 \n",
|
|||
|
"41922 0.0 0.0 ... 1.0 \n",
|
|||
|
"98221 0.0 0.0 ... 1.0 \n",
|
|||
|
"151717 0.0 0.0 ... 1.0 \n",
|
|||
|
"\n",
|
|||
|
" HIVTesting_Yes FluVaxLast12_Yes PneumoVaxEver_Yes \\\n",
|
|||
|
"6432 1.0 1.0 0.0 \n",
|
|||
|
"61767 1.0 0.0 0.0 \n",
|
|||
|
"102005 0.0 0.0 0.0 \n",
|
|||
|
"183791 1.0 0.0 0.0 \n",
|
|||
|
"230656 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"93877 0.0 0.0 0.0 \n",
|
|||
|
"117856 0.0 0.0 0.0 \n",
|
|||
|
"41922 0.0 1.0 0.0 \n",
|
|||
|
"98221 1.0 0.0 1.0 \n",
|
|||
|
"151717 1.0 1.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap_Yes, received Tdap \\\n",
|
|||
|
"6432 0.0 \n",
|
|||
|
"61767 1.0 \n",
|
|||
|
"102005 1.0 \n",
|
|||
|
"183791 0.0 \n",
|
|||
|
"230656 0.0 \n",
|
|||
|
"... ... \n",
|
|||
|
"93877 1.0 \n",
|
|||
|
"117856 0.0 \n",
|
|||
|
"41922 0.0 \n",
|
|||
|
"98221 1.0 \n",
|
|||
|
"151717 1.0 \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap_Yes, received tetanus shot but not sure what type \\\n",
|
|||
|
"6432 0.0 \n",
|
|||
|
"61767 0.0 \n",
|
|||
|
"102005 0.0 \n",
|
|||
|
"183791 1.0 \n",
|
|||
|
"230656 0.0 \n",
|
|||
|
"... ... \n",
|
|||
|
"93877 0.0 \n",
|
|||
|
"117856 1.0 \n",
|
|||
|
"41922 1.0 \n",
|
|||
|
"98221 0.0 \n",
|
|||
|
"151717 0.0 \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap \\\n",
|
|||
|
"6432 0.0 \n",
|
|||
|
"61767 0.0 \n",
|
|||
|
"102005 0.0 \n",
|
|||
|
"183791 0.0 \n",
|
|||
|
"230656 0.0 \n",
|
|||
|
"... ... \n",
|
|||
|
"93877 0.0 \n",
|
|||
|
"117856 0.0 \n",
|
|||
|
"41922 0.0 \n",
|
|||
|
"98221 0.0 \n",
|
|||
|
"151717 0.0 \n",
|
|||
|
"\n",
|
|||
|
" HighRiskLastYear_Yes \\\n",
|
|||
|
"6432 0.0 \n",
|
|||
|
"61767 0.0 \n",
|
|||
|
"102005 0.0 \n",
|
|||
|
"183791 0.0 \n",
|
|||
|
"230656 0.0 \n",
|
|||
|
"... ... \n",
|
|||
|
"93877 0.0 \n",
|
|||
|
"117856 0.0 \n",
|
|||
|
"41922 0.0 \n",
|
|||
|
"98221 0.0 \n",
|
|||
|
"151717 0.0 \n",
|
|||
|
"\n",
|
|||
|
" CovidPos_Tested positive using home test without a health professional \\\n",
|
|||
|
"6432 0.0 \n",
|
|||
|
"61767 0.0 \n",
|
|||
|
"102005 0.0 \n",
|
|||
|
"183791 0.0 \n",
|
|||
|
"230656 0.0 \n",
|
|||
|
"... ... \n",
|
|||
|
"93877 0.0 \n",
|
|||
|
"117856 0.0 \n",
|
|||
|
"41922 0.0 \n",
|
|||
|
"98221 0.0 \n",
|
|||
|
"151717 0.0 \n",
|
|||
|
"\n",
|
|||
|
" CovidPos_Yes \n",
|
|||
|
"6432 1.0 \n",
|
|||
|
"61767 0.0 \n",
|
|||
|
"102005 1.0 \n",
|
|||
|
"183791 0.0 \n",
|
|||
|
"230656 0.0 \n",
|
|||
|
"... ... \n",
|
|||
|
"93877 1.0 \n",
|
|||
|
"117856 1.0 \n",
|
|||
|
"41922 0.0 \n",
|
|||
|
"98221 0.0 \n",
|
|||
|
"151717 1.0 \n",
|
|||
|
"\n",
|
|||
|
"[196817 rows x 109 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 57,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.discriminant_analysis import StandardScaler\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder\n",
|
|||
|
"\n",
|
|||
|
"columns_to_drop = ['AgeCategory', 'Sex']\n",
|
|||
|
"num_columns = [\n",
|
|||
|
" column\n",
|
|||
|
" for column in df.columns\n",
|
|||
|
" if column not in columns_to_drop and df[column].dtype != \"object\"\n",
|
|||
|
"]\n",
|
|||
|
"cat_columns = [\n",
|
|||
|
" column\n",
|
|||
|
" for column in df.columns\n",
|
|||
|
" if column not in columns_to_drop and df[column].dtype == \"object\"\n",
|
|||
|
"]\n",
|
|||
|
"\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"preprocessed_df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формируем набор моделей"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 58,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"class_models = {\n",
|
|||
|
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
|
|||
|
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
|
|||
|
" \"decision_tree\": {\n",
|
|||
|
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
|
|||
|
" },\n",
|
|||
|
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
|
|||
|
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
|
|||
|
" \"gradient_boosting\": {\n",
|
|||
|
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
|
|||
|
" },\n",
|
|||
|
" \"random_forest\": {\n",
|
|||
|
" \"model\": ensemble.RandomForestClassifier(\n",
|
|||
|
" max_depth=11, class_weight=\"balanced\", random_state=9\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"mlp\": {\n",
|
|||
|
" \"model\": neural_network.MLPClassifier(\n",
|
|||
|
" hidden_layer_sizes=(7,),\n",
|
|||
|
" max_iter=500,\n",
|
|||
|
" early_stopping=True,\n",
|
|||
|
" random_state=9,\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
"}"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучаем модели и тестируем их"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 59,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"Model: knn\n",
|
|||
|
"Model: naive_bayes\n",
|
|||
|
"Model: gradient_boosting\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
"\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
|||
|
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
|||
|
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
|
|||
|
" y_test, y_test_probs\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict, average=None)\n",
|
|||
|
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict, average=None)\n",
|
|||
|
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Матрица неточностей"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 60,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7QAAAQ9CAYAAABp3wEwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxN6eMH8M+9t31XqaS0SEgIGdPYiezrfLPLPgzGMtbvWIqxjUGWwQySMUaYsX0xGSQzg2EsGUuyyxISlaLt3vP7w6/DdetUVLfl8369zmum8zznOc895NNzes5zZIIgCCAiIiIiIiIqZeTa7gARERERERHR++CAloiIiIiIiEolDmiJiIiIiIioVOKAloiIiIiIiEolDmiJiIiIiIioVOKAloiIiIiIiEolDmiJiIiIiIioVOKAloiIiIiIiEolDmiJiIiIiIioVOKAlqgMCA0NhUwmw507d4qk/Tt37kAmkyE0NLRQ2ouMjIRMJkNkZGShtEdERFRWBAYGQiaT5auuTCZDYGBg0XaIqITjgJaIiszq1asLbRBMRERERPQuHW13gIhKPicnJ7x69Qq6uroFOm716tWwtrbGoEGD1PY3a9YMr169gp6eXiH2koiIqPSbMWMGpk2bpu1uEJUaHNASUZ5kMhkMDAwKrT25XF6o7REREZUFqampMDY2ho4Of0Qnyi9OOSYqo1avXo1atWpBX18f9vb2GD16NBITEzXqfffdd3B1dYWhoSE++ugj/Pnnn2jRogVatGgh1snpGdpHjx5h8ODBcHBwgL6+PipVqoSuXbuKz/E6Ozvj8uXLOHbsGGQyGWQymdhmbs/Qnjp1Ch06dECFChVgbGyMOnXqYPny5YV7YYiIiEqA7Gdlr1y5gr59+6JChQpo0qRJjs/QpqenY8KECahYsSJMTU3RpUsX3L9/P8d2IyMj4e3tDQMDA1StWhXff/99rs/l/vTTT2jQoAEMDQ1haWmJ3r174969e0XyeYmKCm//EJVBgYGBCAoKgq+vL0aNGoWYmBisWbMG//zzD44fPy5OHV6zZg3GjBmDpk2bYsKECbhz5w66deuGChUqwMHBQfIcPXv2xOXLlzF27Fg4OzvjyZMnOHToEGJjY+Hs7Izg4GCMHTsWJiYm+OqrrwAAtra2ubZ36NAhdOrUCZUqVcK4ceNgZ2eH6Oho7Nu3D+PGjSu8i0NERFSC/Oc//0G1atUwf/58CIKAJ0+eaNQZNmwYfvrpJ/Tt2xeffPIJIiIi0LFjR41658+fR7t27VCpUiUEBQVBqVRizpw5qFixokbdefPmYebMmfD398ewYcMQHx+PlStXolmzZjh//jwsLCyK4uMSFT6BiEq9jRs3CgCE27dvC0+ePBH09PSEtm3bCkqlUqyzatUqAYAQEhIiCIIgpKenC1ZWVkLDhg2FzMxMsV5oaKgAQGjevLm47/bt2wIAYePGjYIgCMLz588FAMLixYsl+1WrVi21drIdPXpUACAcPXpUEARByMrKElxcXAQnJyfh+fPnanVVKlX+LwQREVEpMXv2bAGA0KdPnxz3Z4uKihIACJ9//rlavb59+woAhNmzZ4v7OnfuLBgZGQkPHjwQ912/fl3Q0dFRa/POnTuCQqEQ5s2bp9bmxYsXBR0dHY39RCUZpxwTlTGHDx9GRkYGxo8fD7n8zbf48OHDYWZmhv379wMAzpw5g4SEBAwfPlztWZ1+/fqhQoUKkucwNDSEnp4eIiMj8fz58w/u8/nz53H79m2MHz9e445wfl9dQEREVBqNHDlSsvzAgQMAgC+++EJt//jx49W+ViqVOHz4MLp16wZ7e3txv5ubG9q3b69Wd+fOnVCpVPD398fTp0/Fzc7ODtWqVcPRo0c/4BMRFS9OOSYqY+7evQsAqF69utp+PT09uLq6iuXZ/3Vzc1Orp6OjA2dnZ8lz6OvrY9GiRfjyyy9ha2uLjz/+GJ06dcLAgQNhZ2dX4D7fvHkTAODp6VngY4mIiEozFxcXyfK7d+9CLpejatWqavvfzfknT57g1atXGrkOaGb99evXIQgCqlWrluM5C/pWAyJt4oCWiN7L+PHj0blzZ+zevRsHDx7EzJkzsWDBAkRERKBevXra7h4REVGpYGhoWOznVKlUkMlk+O2336BQKDTKTUxMir1PRO+LU46JyhgnJycAQExMjNr+jIwM3L59WyzP/u+NGzfU6mVlZYkrFeelatWq+PLLL/H777/j0qVLyMjIwJIlS8Ty/E4Xzr7rfOnSpXzVJyIiKi+cnJygUqnE2UzZ3s15GxsbGBgYaOQ6oJn1VatWhSAIcHFxga+vr8b28ccfF/4HISoiHNASlTG+vr7Q09PDihUrIAiCuH/Dhg1ISkoSV0X09vaGlZUV1q1bh6ysLLHeli1b8nwu9uXLl0hLS1PbV7VqVZiamiI9PV3cZ2xsnOOrgt5Vv359uLi4IDg4WKP+25+BiIiovMl+/nXFihVq+4ODg9W+VigU8PX1xe7du/Hw4UNx/40bN/Dbb7+p1e3RowcUCgWCgoI0clYQBCQkJBTiJyAqWpxyTFTGVKxYEdOnT0dQUBDatWuHLl26ICYmBqtXr0bDhg3Rv39/AK+fqQ0MDMTYsWPRqlUr+Pv7486dOwgNDUXVqlUlf7t67do1tG7dGv7+/vDw8ICOjg527dqFx48fo3fv3mK9Bg0aYM2aNfj666/h5uYGGxsbtGrVSqM9uVyONWvWoHPnzvDy8sLgwYNRqVIlXL16FZcvX8bBgwcL/0IRERGVAl5eXujTpw9Wr16NpKQkfPLJJzhy5EiOv4kNDAzE77//jsaNG2PUqFFQKpVYtWoVPD09ERUVJdarWrUqvv76a0yfPl18ZZ+pqSlu376NXbt2YcSIEZg0aVIxfkqi98cBLVEZFBgYiIoVK2LVqlWYMGECLC0tMWLECMyfP19toYcxY8ZAEAQsWbIEkyZNQt26dbF371588cUXMDAwyLV9R0dH9OnTB0eOHMHmzZuho6ODGjVqYPv27ejZs6dYb9asWbh79y6++eYbvHjxAs2bN89xQAsAfn5+OHr0KIKCgrBkyRKoVCpUrVoVw4cPL7wLQ0REVAqFhISgYsWK2LJlC3bv3o1WrVph//79cHR0VKvXoEED/Pbbb5g0aRJmzpwJR0dHzJkzB9HR0bh69apa3WnTpsHd3R3Lli1DUFAQgNf53rZtW3Tp0qXYPhvRh5IJnM9HRG9RqVSoWLEievTogXXr1mm7O0RERPSBunXrhsuXL+P69eva7gpRoeMztETlWFpamsazMz/++COePXuGFi1aaKdTRERE9N5evXql9vX169dx4MAB5jqVWfwNLVE5FhkZiQkTJuA///kPrKyscO7cOWzYsAE1a9bE2bNnoaenp+0uEhERUQFUqlQJgwYNEt89v2bNGqSnp+P8+fO5vneWqDTjM7RE5ZizszMcHR2xYsUKPHv2DJaWlhg4cCAWLlzIwSwREVEp1K5dO2zduhWPHj2Cvr4+fHx8MH/+fA5mqczib2iJiIiIiIioVOIztERERERERFQqcUBLREREREREpRKfoaVSR6VS4eHDhzA1NYVMJtN2d4iKlSAIePHiBezt7SGXF+49ybS0NGRkZORZT09PT/I9xURU/jCbqTxjNmsXB7RU6jx8+FDjReJE5c29e/fg4OBQaO2lpaXBxckEj54o86xrZ2eH27dvl9vgJCJNzGYiZrO2cEBLpY6pqSkA4O45Z5iZcNa8NnR3r63tLpRbWcjEXzggfh8UloyMDDx6osSNM44wM839+yr5hQpu3veQkZFRLkOTiHLGbNY+ZrP2MJu1iwNaKnWypzKZmcglv7mp6OjIdLXdhfLr/9elL6opfSamMpiY5t62CpxKSESamM3ax2zWImazVnFAS0REokxBiUyJt7llCqpi7A0RERExm6VxQEtERCIVBKiQe2hKlREREVHhYzZL44CWiIhEKghQMjSJiIhKDGazNA5oiYhIlCmokCmRi+V9WhMREVFxYzZL44CWiIhEqv/fpMqJiIio+DCbpXFAS0REImUe05qkyoiIiKj
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x1000 with 16 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from matplotlib import pyplot as plt\n",
|
|||
|
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
|||
|
"\n",
|
|||
|
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
|
|||
|
"for index, key in enumerate(class_models.keys()):\n",
|
|||
|
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[f\"No {TARGET_COLUMN_NAME_CLASSIFICATION}\", TARGET_COLUMN_NAME_CLASSIFICATION]\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(key)\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Точность, полнота, верность (аккуратность), F-мера"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 61,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Precision_train</th>\n",
|
|||
|
" <th>Precision_test</th>\n",
|
|||
|
" <th>Recall_train</th>\n",
|
|||
|
" <th>Recall_test</th>\n",
|
|||
|
" <th>Accuracy_train</th>\n",
|
|||
|
" <th>Accuracy_test</th>\n",
|
|||
|
" <th>F1_train</th>\n",
|
|||
|
" <th>F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>logistic</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>ridge</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>decision_tree</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>knn</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.999907</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.999995</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[0.9999973128320332, 0.9999534775529193]</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>naive_bayes</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>gradient_boosting</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>random_forest</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mlp</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Precision_train Precision_test Recall_train Recall_test \\\n",
|
|||
|
"logistic 1.0 1.0 1.000000 1.0 \n",
|
|||
|
"ridge 1.0 1.0 1.000000 1.0 \n",
|
|||
|
"decision_tree 1.0 1.0 1.000000 1.0 \n",
|
|||
|
"knn 1.0 1.0 0.999907 1.0 \n",
|
|||
|
"naive_bayes 1.0 1.0 1.000000 1.0 \n",
|
|||
|
"gradient_boosting 1.0 1.0 1.000000 1.0 \n",
|
|||
|
"random_forest 1.0 1.0 1.000000 1.0 \n",
|
|||
|
"mlp 1.0 1.0 1.000000 1.0 \n",
|
|||
|
"\n",
|
|||
|
" Accuracy_train Accuracy_test \\\n",
|
|||
|
"logistic 1.000000 1.0 \n",
|
|||
|
"ridge 1.000000 1.0 \n",
|
|||
|
"decision_tree 1.000000 1.0 \n",
|
|||
|
"knn 0.999995 1.0 \n",
|
|||
|
"naive_bayes 1.000000 1.0 \n",
|
|||
|
"gradient_boosting 1.000000 1.0 \n",
|
|||
|
"random_forest 1.000000 1.0 \n",
|
|||
|
"mlp 1.000000 1.0 \n",
|
|||
|
"\n",
|
|||
|
" F1_train F1_test \n",
|
|||
|
"logistic [1.0, 1.0] [1.0, 1.0] \n",
|
|||
|
"ridge [1.0, 1.0] [1.0, 1.0] \n",
|
|||
|
"decision_tree [1.0, 1.0] [1.0, 1.0] \n",
|
|||
|
"knn [0.9999973128320332, 0.9999534775529193] [1.0, 1.0] \n",
|
|||
|
"naive_bayes [1.0, 1.0] [1.0, 1.0] \n",
|
|||
|
"gradient_boosting [1.0, 1.0] [1.0, 1.0] \n",
|
|||
|
"random_forest [1.0, 1.0] [1.0, 1.0] \n",
|
|||
|
"mlp [1.0, 1.0] [1.0, 1.0] "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 61,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(\n",
|
|||
|
" by=\"Accuracy_test\", ascending=False\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 62,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Accuracy_test</th>\n",
|
|||
|
" <th>F1_test</th>\n",
|
|||
|
" <th>ROC_AUC_test</th>\n",
|
|||
|
" <th>Cohen_kappa_test</th>\n",
|
|||
|
" <th>MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>logistic</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>ridge</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>decision_tree</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>knn</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>naive_bayes</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>gradient_boosting</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>random_forest</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>mlp</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Accuracy_test F1_test ROC_AUC_test Cohen_kappa_test \\\n",
|
|||
|
"logistic 1.0 [1.0, 1.0] 1.0 1.0 \n",
|
|||
|
"ridge 1.0 [1.0, 1.0] 1.0 1.0 \n",
|
|||
|
"decision_tree 1.0 [1.0, 1.0] 1.0 1.0 \n",
|
|||
|
"knn 1.0 [1.0, 1.0] 1.0 1.0 \n",
|
|||
|
"naive_bayes 1.0 [1.0, 1.0] 1.0 1.0 \n",
|
|||
|
"gradient_boosting 1.0 [1.0, 1.0] 1.0 1.0 \n",
|
|||
|
"random_forest 1.0 [1.0, 1.0] 1.0 1.0 \n",
|
|||
|
"mlp 1.0 [1.0, 1.0] 1.0 1.0 \n",
|
|||
|
"\n",
|
|||
|
" MCC_test \n",
|
|||
|
"logistic 1.0 \n",
|
|||
|
"ridge 1.0 \n",
|
|||
|
"decision_tree 1.0 \n",
|
|||
|
"knn 1.0 \n",
|
|||
|
"naive_bayes 1.0 \n",
|
|||
|
"gradient_boosting 1.0 \n",
|
|||
|
"random_forest 1.0 \n",
|
|||
|
"mlp 1.0 "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 62,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Лучшая модель"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 63,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'logistic'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
|
|||
|
"\n",
|
|||
|
"display(best_model)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Находим ошибки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 64,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'Error items count: 0'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>State</th>\n",
|
|||
|
" <th>Predicted</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>GeneralHealth</th>\n",
|
|||
|
" <th>PhysicalHealthDays</th>\n",
|
|||
|
" <th>MentalHealthDays</th>\n",
|
|||
|
" <th>LastCheckupTime</th>\n",
|
|||
|
" <th>PhysicalActivities</th>\n",
|
|||
|
" <th>SleepHours</th>\n",
|
|||
|
" <th>RemovedTeeth</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>HeightInMeters</th>\n",
|
|||
|
" <th>WeightInKilograms</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>AlcoholDrinkers</th>\n",
|
|||
|
" <th>HIVTesting</th>\n",
|
|||
|
" <th>FluVaxLast12</th>\n",
|
|||
|
" <th>PneumoVaxEver</th>\n",
|
|||
|
" <th>TetanusLast10Tdap</th>\n",
|
|||
|
" <th>HighRiskLastYear</th>\n",
|
|||
|
" <th>CovidPos</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>0 rows × 41 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"Empty DataFrame\n",
|
|||
|
"Columns: [State, Predicted, Sex, GeneralHealth, PhysicalHealthDays, MentalHealthDays, LastCheckupTime, PhysicalActivities, SleepHours, RemovedTeeth, HadHeartAttack, HadAngina, HadStroke, HadAsthma, HadSkinCancer, HadCOPD, HadDepressiveDisorder, HadKidneyDisease, HadArthritis, HadDiabetes, DeafOrHardOfHearing, BlindOrVisionDifficulty, DifficultyConcentrating, DifficultyWalking, DifficultyDressingBathing, DifficultyErrands, SmokerStatus, ECigaretteUsage, ChestScan, RaceEthnicityCategory, AgeCategory, HeightInMeters, WeightInKilograms, BMI, AlcoholDrinkers, HIVTesting, FluVaxLast12, PneumoVaxEver, TetanusLast10Tdap, HighRiskLastYear, CovidPos]\n",
|
|||
|
"Index: []\n",
|
|||
|
"\n",
|
|||
|
"[0 rows x 41 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 64,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"preprocessing_result = pipeline_end.transform(X_test)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"y_new_pred = class_models[best_model][\"preds\"]\n",
|
|||
|
"\n",
|
|||
|
"error_index = y_test[y_test[TARGET_COLUMN_NAME_CLASSIFICATION] != y_new_pred].index.tolist()\n",
|
|||
|
"display(f\"Error items count: {len(error_index)}\")\n",
|
|||
|
"\n",
|
|||
|
"error_predicted = pd.Series(y_new_pred, index=y_test.index).loc[error_index]\n",
|
|||
|
"error_df = X_test.loc[error_index].copy()\n",
|
|||
|
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
|
|||
|
"error_df.sort_index()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Пример использования модели (конвейера) для предсказания"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 65,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>State</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>GeneralHealth</th>\n",
|
|||
|
" <th>PhysicalHealthDays</th>\n",
|
|||
|
" <th>MentalHealthDays</th>\n",
|
|||
|
" <th>LastCheckupTime</th>\n",
|
|||
|
" <th>PhysicalActivities</th>\n",
|
|||
|
" <th>SleepHours</th>\n",
|
|||
|
" <th>RemovedTeeth</th>\n",
|
|||
|
" <th>HadHeartAttack</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>HeightInMeters</th>\n",
|
|||
|
" <th>WeightInKilograms</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>AlcoholDrinkers</th>\n",
|
|||
|
" <th>HIVTesting</th>\n",
|
|||
|
" <th>FluVaxLast12</th>\n",
|
|||
|
" <th>PneumoVaxEver</th>\n",
|
|||
|
" <th>TetanusLast10Tdap</th>\n",
|
|||
|
" <th>HighRiskLastYear</th>\n",
|
|||
|
" <th>CovidPos</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>187130</th>\n",
|
|||
|
" <td>South Dakota</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Poor</td>\n",
|
|||
|
" <td>30.0</td>\n",
|
|||
|
" <td>30.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>4.0</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.83</td>\n",
|
|||
|
" <td>97.98</td>\n",
|
|||
|
" <td>29.29</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>1 rows × 40 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" State Sex GeneralHealth PhysicalHealthDays MentalHealthDays \\\n",
|
|||
|
"187130 South Dakota Male Poor 30.0 30.0 \n",
|
|||
|
"\n",
|
|||
|
" LastCheckupTime PhysicalActivities \\\n",
|
|||
|
"187130 Within past year (anytime less than 12 months ... No \n",
|
|||
|
"\n",
|
|||
|
" SleepHours RemovedTeeth HadHeartAttack ... HeightInMeters \\\n",
|
|||
|
"187130 4.0 None of them 0 ... 1.83 \n",
|
|||
|
"\n",
|
|||
|
" WeightInKilograms BMI AlcoholDrinkers HIVTesting FluVaxLast12 \\\n",
|
|||
|
"187130 97.98 29.29 Yes No No \n",
|
|||
|
"\n",
|
|||
|
" PneumoVaxEver TetanusLast10Tdap \\\n",
|
|||
|
"187130 No No, did not receive any tetanus shot in the pa... \n",
|
|||
|
"\n",
|
|||
|
" HighRiskLastYear CovidPos \n",
|
|||
|
"187130 No No \n",
|
|||
|
"\n",
|
|||
|
"[1 rows x 40 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>PhysicalHealthDays</th>\n",
|
|||
|
" <th>MentalHealthDays</th>\n",
|
|||
|
" <th>SleepHours</th>\n",
|
|||
|
" <th>HadHeartAttack</th>\n",
|
|||
|
" <th>HeightInMeters</th>\n",
|
|||
|
" <th>WeightInKilograms</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>State_Alaska</th>\n",
|
|||
|
" <th>State_Arizona</th>\n",
|
|||
|
" <th>State_Arkansas</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>AlcoholDrinkers_Yes</th>\n",
|
|||
|
" <th>HIVTesting_Yes</th>\n",
|
|||
|
" <th>FluVaxLast12_Yes</th>\n",
|
|||
|
" <th>PneumoVaxEver_Yes</th>\n",
|
|||
|
" <th>TetanusLast10Tdap_Yes, received Tdap</th>\n",
|
|||
|
" <th>TetanusLast10Tdap_Yes, received tetanus shot but not sure what type</th>\n",
|
|||
|
" <th>TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap</th>\n",
|
|||
|
" <th>HighRiskLastYear_Yes</th>\n",
|
|||
|
" <th>CovidPos_Tested positive using home test without a health professional</th>\n",
|
|||
|
" <th>CovidPos_Yes</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>187130</th>\n",
|
|||
|
" <td>3.077503</td>\n",
|
|||
|
" <td>3.188668</td>\n",
|
|||
|
" <td>-2.094884</td>\n",
|
|||
|
" <td>-0.24034</td>\n",
|
|||
|
" <td>1.170796</td>\n",
|
|||
|
" <td>0.67449</td>\n",
|
|||
|
" <td>0.096168</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>1 rows × 109 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" PhysicalHealthDays MentalHealthDays SleepHours HadHeartAttack \\\n",
|
|||
|
"187130 3.077503 3.188668 -2.094884 -0.24034 \n",
|
|||
|
"\n",
|
|||
|
" HeightInMeters WeightInKilograms BMI State_Alaska \\\n",
|
|||
|
"187130 1.170796 0.67449 0.096168 0.0 \n",
|
|||
|
"\n",
|
|||
|
" State_Arizona State_Arkansas ... AlcoholDrinkers_Yes \\\n",
|
|||
|
"187130 0.0 0.0 ... 1.0 \n",
|
|||
|
"\n",
|
|||
|
" HIVTesting_Yes FluVaxLast12_Yes PneumoVaxEver_Yes \\\n",
|
|||
|
"187130 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap_Yes, received Tdap \\\n",
|
|||
|
"187130 0.0 \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap_Yes, received tetanus shot but not sure what type \\\n",
|
|||
|
"187130 0.0 \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap \\\n",
|
|||
|
"187130 0.0 \n",
|
|||
|
"\n",
|
|||
|
" HighRiskLastYear_Yes \\\n",
|
|||
|
"187130 0.0 \n",
|
|||
|
"\n",
|
|||
|
" CovidPos_Tested positive using home test without a health professional \\\n",
|
|||
|
"187130 0.0 \n",
|
|||
|
"\n",
|
|||
|
" CovidPos_Yes \n",
|
|||
|
"187130 0.0 \n",
|
|||
|
"\n",
|
|||
|
"[1 rows x 109 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'predicted: 0 (proba: [9.99540301e-01 4.59698535e-04])'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'real: 0'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"model = class_models[best_model][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"example_id = 187130\n",
|
|||
|
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
|
|||
|
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
|
|||
|
"display(test)\n",
|
|||
|
"display(test_preprocessed)\n",
|
|||
|
"result_proba = model.predict_proba(test)[0]\n",
|
|||
|
"result = model.predict(test)[0]\n",
|
|||
|
"real = int(y_test.loc[example_id].values[0])\n",
|
|||
|
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
|
|||
|
"display(f\"real: {real}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Создаем гиперпараметры методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 66,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
|||
|
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'model__criterion': 'gini',\n",
|
|||
|
" 'model__max_depth': 10,\n",
|
|||
|
" 'model__max_features': 'sqrt',\n",
|
|||
|
" 'model__n_estimators': 100}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 66,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import GridSearchCV\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"optimized_model_type = 'random_forest'\n",
|
|||
|
"random_state = 9\n",
|
|||
|
"\n",
|
|||
|
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" \"model__n_estimators\": [10, 50, 100],\n",
|
|||
|
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
|
|||
|
" \"model__max_depth\": [5, 7, 10],\n",
|
|||
|
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"gs_optomizer = GridSearchCV(\n",
|
|||
|
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
|
|||
|
")\n",
|
|||
|
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"gs_optomizer.best_params_\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение модели с новыми гиперпараметрами"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 67,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"optimized_model = ensemble.RandomForestClassifier(\n",
|
|||
|
" random_state=42,\n",
|
|||
|
" criterion=\"gini\",\n",
|
|||
|
" max_depth=5,\n",
|
|||
|
" max_features=\"sqrt\",\n",
|
|||
|
" n_estimators=50,\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"result = {}\n",
|
|||
|
"\n",
|
|||
|
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
|
|||
|
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
|
|||
|
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
|
|||
|
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
|
|||
|
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование данных для оценки старой и новой версии модели и сама оценка данных"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 68,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Precision_train</th>\n",
|
|||
|
" <th>Precision_test</th>\n",
|
|||
|
" <th>Recall_train</th>\n",
|
|||
|
" <th>Recall_test</th>\n",
|
|||
|
" <th>Accuracy_train</th>\n",
|
|||
|
" <th>Accuracy_test</th>\n",
|
|||
|
" <th>F1_train</th>\n",
|
|||
|
" <th>F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>Old</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>New</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.304987</td>\n",
|
|||
|
" <td>0.298846</td>\n",
|
|||
|
" <td>0.962046</td>\n",
|
|||
|
" <td>0.961711</td>\n",
|
|||
|
" <td>0.467418</td>\n",
|
|||
|
" <td>0.460172</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Precision_train Precision_test Recall_train Recall_test Accuracy_train \\\n",
|
|||
|
"Name \n",
|
|||
|
"Old 1.0 1.0 1.0 1.0 1.0 \n",
|
|||
|
"New 1.0 1.0 0.304987 0.298846 0.962046 \n",
|
|||
|
"\n",
|
|||
|
" Accuracy_test F1_train F1_test \n",
|
|||
|
"Name \n",
|
|||
|
"Old 1.0 [1.0, 1.0] [1.0, 1.0] \n",
|
|||
|
"New 0.961711 0.467418 0.460172 "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 68,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=class_models[optimized_model_type]\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=result\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
|
|||
|
"optimized_metrics = optimized_metrics.set_index(\"Name\")\n",
|
|||
|
"\n",
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 69,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Accuracy_test</th>\n",
|
|||
|
" <th>F1_test</th>\n",
|
|||
|
" <th>ROC_AUC_test</th>\n",
|
|||
|
" <th>Cohen_kappa_test</th>\n",
|
|||
|
" <th>MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th></th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>Old</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>[1.0, 1.0]</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>New</th>\n",
|
|||
|
" <td>0.961711</td>\n",
|
|||
|
" <td>0.460172</td>\n",
|
|||
|
" <td>0.999994</td>\n",
|
|||
|
" <td>0.446257</td>\n",
|
|||
|
" <td>0.535924</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Accuracy_test F1_test ROC_AUC_test Cohen_kappa_test MCC_test\n",
|
|||
|
"Name \n",
|
|||
|
"Old 1.0 [1.0, 1.0] 1.0 1.0 1.0\n",
|
|||
|
"New 0.961711 0.460172 0.999994 0.446257 0.535924"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 69,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 70,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA9UAAAGsCAYAAADT+IQ/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABsQklEQVR4nO3dd3gU1eLG8XeTkEYahBICgYDUSAlFMSpNkSBYKF6KlFBEQUB6+0lHBOSiFEW8FIOIBlAEBS+IIChFFBClgxRBIPQkJEASdvf3R25W1wTYJBPSvp/nmUeZPXvm7BLyzplz5ozJarVaBQAAAAAAMswppxsAAAAAAEBeRacaAAAAAIBMolMNAAAAAEAm0akGAAAAACCT6FQDAAAAAJBJdKoBAAAAAMgkOtUAAAAAAGQSnWoAAAAAADLJJacbAABAXnDr1i0lJSUZVp+rq6vc3d0Nqw8AgIwg14xDpxoAgHu4deuWypfzUvRFs2F1BgQE6OTJkwX2BAQAkHPINWPRqQYA4B6SkpIUfdGsk7vLycc763dOxV23qHzdP5SUlFQgTz4AADmLXDMWnWoAABzk4+1kyMkHAAC5AblmDDrVAAA4yGy1yGw1ph4AAHIauWYMOtUAADjIIqssyvrZhxF1AACQVeSaMRjrBwAAAAAgkxipBgDAQRZZZMQEN2NqAQAga8g1Y9CpBgDAQWarVWZr1qe4GVEHAABZRa4Zg+nfAAAAAABkEiPVAAA4iAVdAAD5CblmDDrVAAA4yCKrzJx8AADyCXLNGEz/BgAAAAAgkxipBgDAQUyTAwDkJ+SaMRipBgAAAAAgkxipBgDAQTx6BACQn5BrxqBTDQCAgyz/24yoBwCAnEauGYPp3wAAAAAAZBIj1QAAOMhs0KNHjKgDAICsIteMQacaAAAHma0pmxH1AACQ08g1YzD9GwAAAACATGKkGgAAB7GgCwAgPyHXjEGnGgAAB1lkklkmQ+oBACCnkWvGYPo3AAAAAACZxEg1AAAOslhTNiPqAQAgp5FrxmCkGgAAAACATGKkGgAAB5kNuvfMiDoAAMgqcs0YdKoBAHAQJx8AgPyEXDMG078BAAAAAMgkRqoBAHCQxWqSxWrAo0cMqAMAgKwi14xBpxoAAAcxTQ4AkJ+Qa8Zg+jcAAAAAAJnESDUAAA4yy0lmA65Hmw1oCwAAWUWuGYNONQAADrIadO+ZtYDfewYAyB3INWMw/RsAAAAAgExipBoAAAexoAsAID8h14xBpxoAAAeZrU4yWw2498xqQGMAAMgics0YTP8GAAAAACCTGKkGAMBBFplkMeB6tEUF/JI+ACBXINeMwUg1AAAAAACZxEg1AAAOYkEXAEB+Qq4Zg041AAAOMm5Bl4I9TQ4AkDuQa8Zg+jcAAAAAAJnESDUAAA5KWdAl61PcjKgDAICsIteMQacaAAAHWeQkM6ukAgDyCXLNGEz/BgAAAAAgkxipBgDAQSzoAgDIT8g1Y9CpBgDAQRY5ycI0OQBAPkGuGYPp3wAAAAAAZBIj1QAAOMhsNclszfoKp0bUAQBAVpFrxmCkGgAAAACATGKkGgAAB5kNevSIuYDfewYAyB3INWPQqQYAwEEWq5MsBqySaingq6QCAHIHcs0YTP8GACCPmTp1qkwmkwYOHGjbd+vWLfXt21f+/v7y8vJS27ZtdeHCBbv3nT59Wi1btpSnp6dKlCihYcOG6fbt23ZlNm/erDp16sjNzU0VK1ZUZGRkmuO/9957Cg4Olru7u+rXr6+ffvopOz4mAKCAyOu5RqcaAAAHpU6TM2LLrJ9//lkffPCBatasabd/0KBB+uqrr7RixQpt2bJF586dU5s2bf5qu9msli1bKikpSdu3b9fixYsVGRmpsWPH2sqcPHlSLVu2VJMmTbR3714NHDhQL730ktavX28rs2zZMg0ePFjjxo3Tnj17VKtWLYWHh+vixYuZ/kwAgJxBrhmTayartYCP1QMAcA9xcXHy9fXVB3vqysMr63dO3Yy/rVfq7FZsbKx8fHwcfl98fLzq1KmjuXPn6o033lBoaKhmzpyp2NhYFS9eXJ988oleeOEFSdLhw4dVrVo17dixQ4888oj++9//6plnntG5c+dUsmRJSdK8efM0YsQIXbp0Sa6urhoxYoTWrl2r/fv3247ZoUMHxcTEaN26dZKk+vXr66GHHtK7774rSbJYLAoKClL//v01cuTILH83AIDsR64Zm2uMVAMAkEPi4uLstsTExLuW79u3r1q2bKmmTZva7d+9e7eSk5Pt9letWlVly5bVjh07JEk7duxQjRo1bCcekhQeHq64uDgdOHDAVuafdYeHh9vqSEpK0u7du+3KODk5qWnTprYyAICCq6DmGp1qAAAcZJGTYZskBQUFydfX17ZNmTLljseOiorSnj170i0THR0tV1dX+fn52e0vWbKkoqOjbWX+fuKR+nrqa3crExcXp5s3b+ry5csym83plkmtAwCQd5BrxuQaq38DAOAgs9VJZgNWSU2t48yZM3bT5Nzc3NItf+bMGQ0YMEAbNmyQu7t7lo8PAIBErhmFkWoAAHKIj4+P3Xank4/du3fr4sWLqlOnjlxcXOTi4qItW7Zo9uzZcnFxUcmSJZWUlKSYmBi79124cEEBAQGSpICAgDSrpqb++V5lfHx85OHhoWLFisnZ2TndMql1AAAKroKaa3SqAQBwkEUmw7aMePLJJ7Vv3z7t3bvXttWrV0+dOnWy/X+hQoW0ceNG23uOHDmi06dPKywsTJIUFhamffv22a1mumHDBvn4+CgkJMRW5u91pJZJrcPV1VV169a1K2OxWLRx40ZbGQBA3kGuGZNrTP8GAMBBRk+Tc5S3t7eqV69ut69w4cLy9/e37e/Zs6cGDx6sokWLysfHR/3791dYWJgeeeQRSVKzZs0UEhKiLl266K233lJ0dLRGjx6tvn372kYSevfurXfffVfDhw9Xjx49tGnTJi1fvlxr1661HXfw4MGKiIhQvXr19PDDD2vmzJlKSEhQ9+7ds/KVAAByALlmTK7RqQYAIB9455135OTkpLZt2yoxMVHh4eGaO3eu7XVnZ2etWbNGffr0UVhYmAoXLqyIiAhNnDjRVqZ8+fJau3atBg0apFmzZqlMmTJasGCBwsPDbWXat2+vS5cuaezYsYqOjlZoaKjWrVuXZpEXAACyIi/lGs+pBgDgHlKf5/nvXY8b9jzPofW2Zvh5ngAAGIFcMxb3VAMAAAAAkElM/0aeY7FYdO7cOXl7e8tkytiiCAAKHqvVquvXryswMFBOTlm7lmyxmmSxZv33jhF1IP8g1wBkBLmW+9CpRp5z7tw5BQUF5XQzAOQxZ86cUZkyZbJUh0VOMhswycvCRDH8DbkGIDPItdyDTjXyHG9vb0nSH3uC5eNVsP8BI63WlWvkdBOQy9xWsrbqa9vvDiC3IddwN+Qa/olcy33oVCPPSZ0a5+PlJB9vTj5gz8VUKKebgNzmf8txGjGt1mJ1ksWAR48YUQfyD3INd0OuIQ1yLdehUw0AgIPMMsmsrJ/EGFEHAABZRa4Zo2BfUgAAAAAAIAsYqQYAwEFMkwMA5CfkmjHoVAMA4CCzjJniZs56UwAAyDJyzRgF+5ICAAAAAABZwEg1AAAOYpocACA/IdeMUbA/PQAAAAAAWcBINQAADjJbnWQ24Gq8EXUAAJBV5Jox6FQDAOAgq0yyGLCgi7WAP88TAJA7kGvGKNiXFAAAAAAAyAJGqgEAcBDT5AAA+Qm5Zgw61QAAOMhiNclizfoUNyPqAAAgq8g1YxTsSwoAAAAAAGQBI9UAADjILCeZDbgebUQdAABkFblmDDrVAAA4iGlyAID8hFwzRsG+pAAAAAAAQBYwUg0AgIMscpLFgOvRRtQBAEBWkWvGoFMNAICDzFaTzAZMcTOiDgAAsopcM0bBvqQAAAAAAEAWMFINAICDWNAFAJCfkGvGYKQaAAAAAIBMYqQaAAAHWa1Oslizfj3aakAdAABkFblmDDrVAAA4yCyTzDJgQRcD6gAAIKvINWMU7EsKAAAAAABkASPVAAA4yGI1ZjEWi9W
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x400 with 4 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"for index in range(0, len(optimized_metrics)):\n",
|
|||
|
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[f\"No {TARGET_COLUMN_NAME_CLASSIFICATION}\", TARGET_COLUMN_NAME_CLASSIFICATION]\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Модель хорошо классифицировала объекты, которые относятся к \"No HadHeartAttack\" и \"HadHeartAttack\""
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Регрессия"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Бизнес цель 2: \n",
|
|||
|
"Предсказание среднего количества часов сна в день (SleepTime) на основе других факторов."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формируем выборки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 85,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>State</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>GeneralHealth</th>\n",
|
|||
|
" <th>PhysicalHealthDays</th>\n",
|
|||
|
" <th>MentalHealthDays</th>\n",
|
|||
|
" <th>LastCheckupTime</th>\n",
|
|||
|
" <th>PhysicalActivities</th>\n",
|
|||
|
" <th>RemovedTeeth</th>\n",
|
|||
|
" <th>HadHeartAttack</th>\n",
|
|||
|
" <th>HadAngina</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>HeightInMeters</th>\n",
|
|||
|
" <th>WeightInKilograms</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>AlcoholDrinkers</th>\n",
|
|||
|
" <th>HIVTesting</th>\n",
|
|||
|
" <th>FluVaxLast12</th>\n",
|
|||
|
" <th>PneumoVaxEver</th>\n",
|
|||
|
" <th>TetanusLast10Tdap</th>\n",
|
|||
|
" <th>HighRiskLastYear</th>\n",
|
|||
|
" <th>CovidPos</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>108769</th>\n",
|
|||
|
" <td>Minnesota</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.73</td>\n",
|
|||
|
" <td>83.91</td>\n",
|
|||
|
" <td>28.13</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>240750</th>\n",
|
|||
|
" <td>Guam</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Excellent</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past 2 years (1 year but less than 2 ye...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.65</td>\n",
|
|||
|
" <td>70.31</td>\n",
|
|||
|
" <td>25.79</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>100329</th>\n",
|
|||
|
" <td>Michigan</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Excellent</td>\n",
|
|||
|
" <td>3.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.60</td>\n",
|
|||
|
" <td>58.97</td>\n",
|
|||
|
" <td>23.03</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>132628</th>\n",
|
|||
|
" <td>New Hampshire</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>4.0</td>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.70</td>\n",
|
|||
|
" <td>68.04</td>\n",
|
|||
|
" <td>23.49</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>72101</th>\n",
|
|||
|
" <td>Kansas</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.83</td>\n",
|
|||
|
" <td>99.79</td>\n",
|
|||
|
" <td>29.84</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>119879</th>\n",
|
|||
|
" <td>Missouri</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Excellent</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.78</td>\n",
|
|||
|
" <td>61.23</td>\n",
|
|||
|
" <td>19.37</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>103694</th>\n",
|
|||
|
" <td>Michigan</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.63</td>\n",
|
|||
|
" <td>74.84</td>\n",
|
|||
|
" <td>28.32</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>131932</th>\n",
|
|||
|
" <td>Nevada</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.70</td>\n",
|
|||
|
" <td>90.72</td>\n",
|
|||
|
" <td>31.32</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>146867</th>\n",
|
|||
|
" <td>New York</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.68</td>\n",
|
|||
|
" <td>77.11</td>\n",
|
|||
|
" <td>27.44</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>121958</th>\n",
|
|||
|
" <td>Montana</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>All</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.65</td>\n",
|
|||
|
" <td>98.88</td>\n",
|
|||
|
" <td>36.28</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>196817 rows × 39 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" State Sex GeneralHealth PhysicalHealthDays \\\n",
|
|||
|
"108769 Minnesota Male Good 0.0 \n",
|
|||
|
"240750 Guam Male Excellent 0.0 \n",
|
|||
|
"100329 Michigan Female Excellent 3.0 \n",
|
|||
|
"132628 New Hampshire Male Good 4.0 \n",
|
|||
|
"72101 Kansas Male Very good 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"119879 Missouri Female Excellent 0.0 \n",
|
|||
|
"103694 Michigan Female Good 10.0 \n",
|
|||
|
"131932 Nevada Female Good 0.0 \n",
|
|||
|
"146867 New York Female Very good 0.0 \n",
|
|||
|
"121958 Montana Female Good 1.0 \n",
|
|||
|
"\n",
|
|||
|
" MentalHealthDays LastCheckupTime \\\n",
|
|||
|
"108769 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"240750 0.0 Within past 2 years (1 year but less than 2 ye... \n",
|
|||
|
"100329 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"132628 6.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"72101 2.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"... ... ... \n",
|
|||
|
"119879 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"103694 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"131932 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"146867 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"121958 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"\n",
|
|||
|
" PhysicalActivities RemovedTeeth HadHeartAttack HadAngina ... \\\n",
|
|||
|
"108769 Yes None of them No No ... \n",
|
|||
|
"240750 Yes None of them No Yes ... \n",
|
|||
|
"100329 No 1 to 5 No Yes ... \n",
|
|||
|
"132628 Yes 1 to 5 No Yes ... \n",
|
|||
|
"72101 Yes None of them No No ... \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"119879 Yes None of them No No ... \n",
|
|||
|
"103694 Yes 1 to 5 No No ... \n",
|
|||
|
"131932 Yes 1 to 5 No No ... \n",
|
|||
|
"146867 Yes 1 to 5 No No ... \n",
|
|||
|
"121958 Yes All No No ... \n",
|
|||
|
"\n",
|
|||
|
" HeightInMeters WeightInKilograms BMI AlcoholDrinkers HIVTesting \\\n",
|
|||
|
"108769 1.73 83.91 28.13 No Yes \n",
|
|||
|
"240750 1.65 70.31 25.79 Yes No \n",
|
|||
|
"100329 1.60 58.97 23.03 No No \n",
|
|||
|
"132628 1.70 68.04 23.49 Yes No \n",
|
|||
|
"72101 1.83 99.79 29.84 Yes No \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"119879 1.78 61.23 19.37 Yes Yes \n",
|
|||
|
"103694 1.63 74.84 28.32 Yes No \n",
|
|||
|
"131932 1.70 90.72 31.32 No No \n",
|
|||
|
"146867 1.68 77.11 27.44 Yes No \n",
|
|||
|
"121958 1.65 98.88 36.28 Yes Yes \n",
|
|||
|
"\n",
|
|||
|
" FluVaxLast12 PneumoVaxEver \\\n",
|
|||
|
"108769 Yes Yes \n",
|
|||
|
"240750 No No \n",
|
|||
|
"100329 Yes Yes \n",
|
|||
|
"132628 Yes No \n",
|
|||
|
"72101 Yes Yes \n",
|
|||
|
"... ... ... \n",
|
|||
|
"119879 Yes No \n",
|
|||
|
"103694 Yes Yes \n",
|
|||
|
"131932 No No \n",
|
|||
|
"146867 Yes No \n",
|
|||
|
"121958 Yes Yes \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap HighRiskLastYear \\\n",
|
|||
|
"108769 Yes, received tetanus shot but not sure what type Yes \n",
|
|||
|
"240750 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"100329 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"132628 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"72101 Yes, received Tdap No \n",
|
|||
|
"... ... ... \n",
|
|||
|
"119879 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"103694 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"131932 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"146867 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"121958 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"\n",
|
|||
|
" CovidPos \n",
|
|||
|
"108769 Yes \n",
|
|||
|
"240750 Yes \n",
|
|||
|
"100329 No \n",
|
|||
|
"132628 No \n",
|
|||
|
"72101 No \n",
|
|||
|
"... ... \n",
|
|||
|
"119879 No \n",
|
|||
|
"103694 No \n",
|
|||
|
"131932 No \n",
|
|||
|
"146867 Yes \n",
|
|||
|
"121958 No \n",
|
|||
|
"\n",
|
|||
|
"[196817 rows x 39 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>SleepHours</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>108769</th>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>240750</th>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>100329</th>\n",
|
|||
|
" <td>9.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>132628</th>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>72101</th>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>119879</th>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>103694</th>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>131932</th>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>146867</th>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>121958</th>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>196817 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" SleepHours\n",
|
|||
|
"108769 6.0\n",
|
|||
|
"240750 7.0\n",
|
|||
|
"100329 9.0\n",
|
|||
|
"132628 6.0\n",
|
|||
|
"72101 7.0\n",
|
|||
|
"... ...\n",
|
|||
|
"119879 8.0\n",
|
|||
|
"103694 8.0\n",
|
|||
|
"131932 7.0\n",
|
|||
|
"146867 8.0\n",
|
|||
|
"121958 8.0\n",
|
|||
|
"\n",
|
|||
|
"[196817 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>State</th>\n",
|
|||
|
" <th>Sex</th>\n",
|
|||
|
" <th>GeneralHealth</th>\n",
|
|||
|
" <th>PhysicalHealthDays</th>\n",
|
|||
|
" <th>MentalHealthDays</th>\n",
|
|||
|
" <th>LastCheckupTime</th>\n",
|
|||
|
" <th>PhysicalActivities</th>\n",
|
|||
|
" <th>RemovedTeeth</th>\n",
|
|||
|
" <th>HadHeartAttack</th>\n",
|
|||
|
" <th>HadAngina</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>HeightInMeters</th>\n",
|
|||
|
" <th>WeightInKilograms</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>AlcoholDrinkers</th>\n",
|
|||
|
" <th>HIVTesting</th>\n",
|
|||
|
" <th>FluVaxLast12</th>\n",
|
|||
|
" <th>PneumoVaxEver</th>\n",
|
|||
|
" <th>TetanusLast10Tdap</th>\n",
|
|||
|
" <th>HighRiskLastYear</th>\n",
|
|||
|
" <th>CovidPos</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>194767</th>\n",
|
|||
|
" <td>Texas</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.68</td>\n",
|
|||
|
" <td>113.40</td>\n",
|
|||
|
" <td>40.35</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>231923</th>\n",
|
|||
|
" <td>Wisconsin</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.73</td>\n",
|
|||
|
" <td>104.33</td>\n",
|
|||
|
" <td>34.97</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>52815</th>\n",
|
|||
|
" <td>Idaho</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Poor</td>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.73</td>\n",
|
|||
|
" <td>104.33</td>\n",
|
|||
|
" <td>34.97</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>65909</th>\n",
|
|||
|
" <td>Iowa</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>20.0</td>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>All</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.68</td>\n",
|
|||
|
" <td>127.01</td>\n",
|
|||
|
" <td>45.19</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>184154</th>\n",
|
|||
|
" <td>South Dakota</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Excellent</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.60</td>\n",
|
|||
|
" <td>49.90</td>\n",
|
|||
|
" <td>19.49</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Tested positive using home test without a heal...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>57503</th>\n",
|
|||
|
" <td>Indiana</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Fair</td>\n",
|
|||
|
" <td>3.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.63</td>\n",
|
|||
|
" <td>97.52</td>\n",
|
|||
|
" <td>36.90</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>47420</th>\n",
|
|||
|
" <td>Hawaii</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Fair</td>\n",
|
|||
|
" <td>30.0</td>\n",
|
|||
|
" <td>5.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.70</td>\n",
|
|||
|
" <td>77.56</td>\n",
|
|||
|
" <td>26.78</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>186088</th>\n",
|
|||
|
" <td>South Dakota</td>\n",
|
|||
|
" <td>Female</td>\n",
|
|||
|
" <td>Good</td>\n",
|
|||
|
" <td>15.0</td>\n",
|
|||
|
" <td>15.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>1 to 5</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.73</td>\n",
|
|||
|
" <td>54.88</td>\n",
|
|||
|
" <td>18.40</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>11687</th>\n",
|
|||
|
" <td>Arkansas</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Excellent</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.78</td>\n",
|
|||
|
" <td>88.45</td>\n",
|
|||
|
" <td>27.98</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes, received tetanus shot but not sure what type</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>200835</th>\n",
|
|||
|
" <td>Utah</td>\n",
|
|||
|
" <td>Male</td>\n",
|
|||
|
" <td>Very good</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>Within past year (anytime less than 12 months ...</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>None of them</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.91</td>\n",
|
|||
|
" <td>118.39</td>\n",
|
|||
|
" <td>32.62</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" <td>Yes, received Tdap</td>\n",
|
|||
|
" <td>No</td>\n",
|
|||
|
" <td>Yes</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>49205 rows × 39 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" State Sex GeneralHealth PhysicalHealthDays \\\n",
|
|||
|
"194767 Texas Female Good 0.0 \n",
|
|||
|
"231923 Wisconsin Female Good 2.0 \n",
|
|||
|
"52815 Idaho Male Poor 7.0 \n",
|
|||
|
"65909 Iowa Female Good 20.0 \n",
|
|||
|
"184154 South Dakota Female Excellent 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"57503 Indiana Female Fair 3.0 \n",
|
|||
|
"47420 Hawaii Female Fair 30.0 \n",
|
|||
|
"186088 South Dakota Female Good 15.0 \n",
|
|||
|
"11687 Arkansas Male Excellent 0.0 \n",
|
|||
|
"200835 Utah Male Very good 0.0 \n",
|
|||
|
"\n",
|
|||
|
" MentalHealthDays LastCheckupTime \\\n",
|
|||
|
"194767 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"231923 5.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"52815 10.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"65909 10.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"184154 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"... ... ... \n",
|
|||
|
"57503 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"47420 5.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"186088 15.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"11687 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"200835 0.0 Within past year (anytime less than 12 months ... \n",
|
|||
|
"\n",
|
|||
|
" PhysicalActivities RemovedTeeth HadHeartAttack HadAngina ... \\\n",
|
|||
|
"194767 Yes None of them No No ... \n",
|
|||
|
"231923 Yes 1 to 5 No No ... \n",
|
|||
|
"52815 Yes 1 to 5 No Yes ... \n",
|
|||
|
"65909 No All Yes No ... \n",
|
|||
|
"184154 Yes None of them No No ... \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"57503 Yes 1 to 5 Yes No ... \n",
|
|||
|
"47420 Yes None of them No No ... \n",
|
|||
|
"186088 Yes 1 to 5 No No ... \n",
|
|||
|
"11687 Yes None of them No No ... \n",
|
|||
|
"200835 Yes None of them No No ... \n",
|
|||
|
"\n",
|
|||
|
" HeightInMeters WeightInKilograms BMI AlcoholDrinkers HIVTesting \\\n",
|
|||
|
"194767 1.68 113.40 40.35 No No \n",
|
|||
|
"231923 1.73 104.33 34.97 Yes Yes \n",
|
|||
|
"52815 1.73 104.33 34.97 No No \n",
|
|||
|
"65909 1.68 127.01 45.19 No No \n",
|
|||
|
"184154 1.60 49.90 19.49 Yes No \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"57503 1.63 97.52 36.90 Yes Yes \n",
|
|||
|
"47420 1.70 77.56 26.78 No Yes \n",
|
|||
|
"186088 1.73 54.88 18.40 Yes No \n",
|
|||
|
"11687 1.78 88.45 27.98 Yes Yes \n",
|
|||
|
"200835 1.91 118.39 32.62 No Yes \n",
|
|||
|
"\n",
|
|||
|
" FluVaxLast12 PneumoVaxEver \\\n",
|
|||
|
"194767 No No \n",
|
|||
|
"231923 No Yes \n",
|
|||
|
"52815 Yes Yes \n",
|
|||
|
"65909 No No \n",
|
|||
|
"184154 Yes No \n",
|
|||
|
"... ... ... \n",
|
|||
|
"57503 Yes Yes \n",
|
|||
|
"47420 Yes No \n",
|
|||
|
"186088 Yes No \n",
|
|||
|
"11687 Yes No \n",
|
|||
|
"200835 No Yes \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap HighRiskLastYear \\\n",
|
|||
|
"194767 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"231923 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"52815 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"65909 No, did not receive any tetanus shot in the pa... No \n",
|
|||
|
"184154 Yes, received Tdap No \n",
|
|||
|
"... ... ... \n",
|
|||
|
"57503 Yes, received Tdap No \n",
|
|||
|
"47420 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"186088 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"11687 Yes, received tetanus shot but not sure what type No \n",
|
|||
|
"200835 Yes, received Tdap No \n",
|
|||
|
"\n",
|
|||
|
" CovidPos \n",
|
|||
|
"194767 Yes \n",
|
|||
|
"231923 No \n",
|
|||
|
"52815 No \n",
|
|||
|
"65909 Yes \n",
|
|||
|
"184154 Tested positive using home test without a heal... \n",
|
|||
|
"... ... \n",
|
|||
|
"57503 No \n",
|
|||
|
"47420 No \n",
|
|||
|
"186088 Yes \n",
|
|||
|
"11687 No \n",
|
|||
|
"200835 Yes \n",
|
|||
|
"\n",
|
|||
|
"[49205 rows x 39 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>SleepHours</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>194767</th>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>231923</th>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>52815</th>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>65909</th>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>184154</th>\n",
|
|||
|
" <td>7.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>57503</th>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>47420</th>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>186088</th>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>11687</th>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>200835</th>\n",
|
|||
|
" <td>8.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>49205 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" SleepHours\n",
|
|||
|
"194767 8.0\n",
|
|||
|
"231923 8.0\n",
|
|||
|
"52815 6.0\n",
|
|||
|
"65909 8.0\n",
|
|||
|
"184154 7.0\n",
|
|||
|
"... ...\n",
|
|||
|
"57503 6.0\n",
|
|||
|
"47420 6.0\n",
|
|||
|
"186088 6.0\n",
|
|||
|
"11687 8.0\n",
|
|||
|
"200835 8.0\n",
|
|||
|
"\n",
|
|||
|
"[49205 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"df = pd.read_csv(\"csv\\\\heart_2022_no_nans.csv\")\n",
|
|||
|
"\n",
|
|||
|
"TARGET_COLUMN_NAME_REGRESSION = \"SleepHours\"\n",
|
|||
|
"\n",
|
|||
|
"def split_into_train_test(\n",
|
|||
|
" df_input: DataFrame,\n",
|
|||
|
" target_colname: str,\n",
|
|||
|
" frac_train: float = 0.8,\n",
|
|||
|
" random_state: int = None,\n",
|
|||
|
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
|||
|
" \n",
|
|||
|
" if not (0 < frac_train < 1):\n",
|
|||
|
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
|
|||
|
" \n",
|
|||
|
" if target_colname not in df_input.columns:\n",
|
|||
|
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
|
|||
|
" \n",
|
|||
|
" X = df_input.drop(columns=[target_colname])\n",
|
|||
|
" y = df_input[[target_colname]]\n",
|
|||
|
"\n",
|
|||
|
" X_train, X_test, y_train, y_test = train_test_split(\n",
|
|||
|
" X, y,\n",
|
|||
|
" test_size=(1.0 - frac_train),\n",
|
|||
|
" random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" return X_train, X_test, y_train, y_test\n",
|
|||
|
"\n",
|
|||
|
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
|
|||
|
" df, \n",
|
|||
|
" target_colname=TARGET_COLUMN_NAME_REGRESSION, \n",
|
|||
|
" frac_train=0.8, \n",
|
|||
|
" random_state=42\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_train\", X_train)\n",
|
|||
|
"display(\"y_train\", y_train)\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_test\", X_test)\n",
|
|||
|
"display(\"y_test\", y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 86,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def get_filtered_columns(df: DataFrame, no_numeric=False, no_text=False) -> list[str]:\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" Возвращает список колонок по фильтру\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" w = []\n",
|
|||
|
" for column in df.columns:\n",
|
|||
|
" if no_numeric and pd.api.types.is_numeric_dtype(df[column]):\n",
|
|||
|
" continue\n",
|
|||
|
" if no_text and not pd.api.types.is_numeric_dtype(df[column]):\n",
|
|||
|
" continue\n",
|
|||
|
" w.append(column)\n",
|
|||
|
" return w"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Выполним one-hot encoding, чтобы избавиться от категориальных признаков"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 87,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>PhysicalHealthDays</th>\n",
|
|||
|
" <th>MentalHealthDays</th>\n",
|
|||
|
" <th>HeightInMeters</th>\n",
|
|||
|
" <th>WeightInKilograms</th>\n",
|
|||
|
" <th>BMI</th>\n",
|
|||
|
" <th>State_Alaska</th>\n",
|
|||
|
" <th>State_Arizona</th>\n",
|
|||
|
" <th>State_Arkansas</th>\n",
|
|||
|
" <th>State_California</th>\n",
|
|||
|
" <th>State_Colorado</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>AlcoholDrinkers_Yes</th>\n",
|
|||
|
" <th>HIVTesting_Yes</th>\n",
|
|||
|
" <th>FluVaxLast12_Yes</th>\n",
|
|||
|
" <th>PneumoVaxEver_Yes</th>\n",
|
|||
|
" <th>TetanusLast10Tdap_Yes, received Tdap</th>\n",
|
|||
|
" <th>TetanusLast10Tdap_Yes, received tetanus shot but not sure what type</th>\n",
|
|||
|
" <th>TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap</th>\n",
|
|||
|
" <th>HighRiskLastYear_Yes</th>\n",
|
|||
|
" <th>CovidPos_Tested positive using home test without a health professional</th>\n",
|
|||
|
" <th>CovidPos_Yes</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>108769</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.73</td>\n",
|
|||
|
" <td>83.91</td>\n",
|
|||
|
" <td>28.13</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>240750</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.65</td>\n",
|
|||
|
" <td>70.31</td>\n",
|
|||
|
" <td>25.79</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>100329</th>\n",
|
|||
|
" <td>3.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.60</td>\n",
|
|||
|
" <td>58.97</td>\n",
|
|||
|
" <td>23.03</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>132628</th>\n",
|
|||
|
" <td>4.0</td>\n",
|
|||
|
" <td>6.0</td>\n",
|
|||
|
" <td>1.70</td>\n",
|
|||
|
" <td>68.04</td>\n",
|
|||
|
" <td>23.49</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>72101</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>2.0</td>\n",
|
|||
|
" <td>1.83</td>\n",
|
|||
|
" <td>99.79</td>\n",
|
|||
|
" <td>29.84</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>119879</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.78</td>\n",
|
|||
|
" <td>61.23</td>\n",
|
|||
|
" <td>19.37</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>103694</th>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.63</td>\n",
|
|||
|
" <td>74.84</td>\n",
|
|||
|
" <td>28.32</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>131932</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.70</td>\n",
|
|||
|
" <td>90.72</td>\n",
|
|||
|
" <td>31.32</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>146867</th>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.68</td>\n",
|
|||
|
" <td>77.11</td>\n",
|
|||
|
" <td>27.44</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>121958</th>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.65</td>\n",
|
|||
|
" <td>98.88</td>\n",
|
|||
|
" <td>36.28</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>True</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>196817 rows × 121 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" PhysicalHealthDays MentalHealthDays HeightInMeters \\\n",
|
|||
|
"108769 0.0 0.0 1.73 \n",
|
|||
|
"240750 0.0 0.0 1.65 \n",
|
|||
|
"100329 3.0 0.0 1.60 \n",
|
|||
|
"132628 4.0 6.0 1.70 \n",
|
|||
|
"72101 0.0 2.0 1.83 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"119879 0.0 0.0 1.78 \n",
|
|||
|
"103694 10.0 0.0 1.63 \n",
|
|||
|
"131932 0.0 0.0 1.70 \n",
|
|||
|
"146867 0.0 0.0 1.68 \n",
|
|||
|
"121958 1.0 0.0 1.65 \n",
|
|||
|
"\n",
|
|||
|
" WeightInKilograms BMI State_Alaska State_Arizona State_Arkansas \\\n",
|
|||
|
"108769 83.91 28.13 False False False \n",
|
|||
|
"240750 70.31 25.79 False False False \n",
|
|||
|
"100329 58.97 23.03 False False False \n",
|
|||
|
"132628 68.04 23.49 False False False \n",
|
|||
|
"72101 99.79 29.84 False False False \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"119879 61.23 19.37 False False False \n",
|
|||
|
"103694 74.84 28.32 False False False \n",
|
|||
|
"131932 90.72 31.32 False False False \n",
|
|||
|
"146867 77.11 27.44 False False False \n",
|
|||
|
"121958 98.88 36.28 False False False \n",
|
|||
|
"\n",
|
|||
|
" State_California State_Colorado ... AlcoholDrinkers_Yes \\\n",
|
|||
|
"108769 False False ... False \n",
|
|||
|
"240750 False False ... True \n",
|
|||
|
"100329 False False ... False \n",
|
|||
|
"132628 False False ... True \n",
|
|||
|
"72101 False False ... True \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"119879 False False ... True \n",
|
|||
|
"103694 False False ... True \n",
|
|||
|
"131932 False False ... False \n",
|
|||
|
"146867 False False ... True \n",
|
|||
|
"121958 False False ... True \n",
|
|||
|
"\n",
|
|||
|
" HIVTesting_Yes FluVaxLast12_Yes PneumoVaxEver_Yes \\\n",
|
|||
|
"108769 True True True \n",
|
|||
|
"240750 False False False \n",
|
|||
|
"100329 False True True \n",
|
|||
|
"132628 False True False \n",
|
|||
|
"72101 False True True \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"119879 True True False \n",
|
|||
|
"103694 False True True \n",
|
|||
|
"131932 False False False \n",
|
|||
|
"146867 False True False \n",
|
|||
|
"121958 True True True \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap_Yes, received Tdap \\\n",
|
|||
|
"108769 False \n",
|
|||
|
"240750 False \n",
|
|||
|
"100329 False \n",
|
|||
|
"132628 False \n",
|
|||
|
"72101 True \n",
|
|||
|
"... ... \n",
|
|||
|
"119879 False \n",
|
|||
|
"103694 False \n",
|
|||
|
"131932 False \n",
|
|||
|
"146867 False \n",
|
|||
|
"121958 False \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap_Yes, received tetanus shot but not sure what type \\\n",
|
|||
|
"108769 True \n",
|
|||
|
"240750 False \n",
|
|||
|
"100329 True \n",
|
|||
|
"132628 False \n",
|
|||
|
"72101 False \n",
|
|||
|
"... ... \n",
|
|||
|
"119879 True \n",
|
|||
|
"103694 False \n",
|
|||
|
"131932 False \n",
|
|||
|
"146867 False \n",
|
|||
|
"121958 True \n",
|
|||
|
"\n",
|
|||
|
" TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap \\\n",
|
|||
|
"108769 False \n",
|
|||
|
"240750 False \n",
|
|||
|
"100329 False \n",
|
|||
|
"132628 False \n",
|
|||
|
"72101 False \n",
|
|||
|
"... ... \n",
|
|||
|
"119879 False \n",
|
|||
|
"103694 False \n",
|
|||
|
"131932 False \n",
|
|||
|
"146867 False \n",
|
|||
|
"121958 False \n",
|
|||
|
"\n",
|
|||
|
" HighRiskLastYear_Yes \\\n",
|
|||
|
"108769 True \n",
|
|||
|
"240750 False \n",
|
|||
|
"100329 False \n",
|
|||
|
"132628 False \n",
|
|||
|
"72101 False \n",
|
|||
|
"... ... \n",
|
|||
|
"119879 False \n",
|
|||
|
"103694 False \n",
|
|||
|
"131932 False \n",
|
|||
|
"146867 False \n",
|
|||
|
"121958 False \n",
|
|||
|
"\n",
|
|||
|
" CovidPos_Tested positive using home test without a health professional \\\n",
|
|||
|
"108769 False \n",
|
|||
|
"240750 False \n",
|
|||
|
"100329 False \n",
|
|||
|
"132628 False \n",
|
|||
|
"72101 False \n",
|
|||
|
"... ... \n",
|
|||
|
"119879 False \n",
|
|||
|
"103694 False \n",
|
|||
|
"131932 False \n",
|
|||
|
"146867 False \n",
|
|||
|
"121958 False \n",
|
|||
|
"\n",
|
|||
|
" CovidPos_Yes \n",
|
|||
|
"108769 True \n",
|
|||
|
"240750 True \n",
|
|||
|
"100329 False \n",
|
|||
|
"132628 False \n",
|
|||
|
"72101 False \n",
|
|||
|
"... ... \n",
|
|||
|
"119879 False \n",
|
|||
|
"103694 False \n",
|
|||
|
"131932 False \n",
|
|||
|
"146867 True \n",
|
|||
|
"121958 False \n",
|
|||
|
"\n",
|
|||
|
"[196817 rows x 121 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 87,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"cat_features = get_filtered_columns(df, no_numeric=True)\n",
|
|||
|
"\n",
|
|||
|
"X_test = pd.get_dummies(X_test, columns=cat_features, drop_first=True)\n",
|
|||
|
"X_train = pd.get_dummies(X_train, columns=cat_features, drop_first=True)\n",
|
|||
|
"\n",
|
|||
|
"X_test\n",
|
|||
|
"X_train"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Определение перечня алгоритмов решения задачи регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 88,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: linear\n",
|
|||
|
"Model: linear_poly\n",
|
|||
|
"Model: linear_interact\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"Model: knn\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import math\n",
|
|||
|
"from sklearn.pipeline import make_pipeline\n",
|
|||
|
"from sklearn.preprocessing import PolynomialFeatures\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"models = {\n",
|
|||
|
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
|
|||
|
" \"linear_poly\": {\n",
|
|||
|
" \"model\": make_pipeline(\n",
|
|||
|
" PolynomialFeatures(degree=2),\n",
|
|||
|
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"linear_interact\": {\n",
|
|||
|
" \"model\": make_pipeline(\n",
|
|||
|
" PolynomialFeatures(interaction_only=True),\n",
|
|||
|
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
|
|||
|
" \"decision_tree\": {\n",
|
|||
|
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
|
|||
|
" },\n",
|
|||
|
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
|
|||
|
" \"random_forest\": {\n",
|
|||
|
" \"model\": ensemble.RandomForestRegressor(\n",
|
|||
|
" max_depth=7, random_state=random_state, n_jobs=-1\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"mlp\": {\n",
|
|||
|
" \"model\": neural_network.MLPRegressor(\n",
|
|||
|
" activation=\"tanh\",\n",
|
|||
|
" hidden_layer_sizes=(3),\n",
|
|||
|
" max_iter=500,\n",
|
|||
|
" early_stopping=True,\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"for model_name in models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
"\n",
|
|||
|
" fitted_model = models[model_name][\"model\"].fit(\n",
|
|||
|
" X_train.values, y_train.values.ravel()\n",
|
|||
|
" )\n",
|
|||
|
" y_train_pred = fitted_model.predict(X_train.values)\n",
|
|||
|
" y_test_pred = fitted_model.predict(X_test.values)\n",
|
|||
|
" models[model_name][\"fitted\"] = fitted_model\n",
|
|||
|
" models[model_name][\"train_preds\"] = y_train_pred\n",
|
|||
|
" models[model_name][\"preds\"] = y_test_pred\n",
|
|||
|
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
|
|||
|
" metrics.mean_squared_error(y_train, y_train_pred)\n",
|
|||
|
" )\n",
|
|||
|
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
|
|||
|
" metrics.mean_squared_error(y_test, y_test_pred)\n",
|
|||
|
" )\n",
|
|||
|
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
|
|||
|
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
|
|||
|
" )\n",
|
|||
|
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Выводим результаты оценки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 89,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_562d9_row0_col0 {\n",
|
|||
|
" background-color: #9dd93b;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row0_col1, #T_562d9_row1_col1, #T_562d9_row2_col1, #T_562d9_row7_col0 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row0_col2, #T_562d9_row1_col2, #T_562d9_row2_col2 {\n",
|
|||
|
" background-color: #5002a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row0_col3, #T_562d9_row1_col3, #T_562d9_row2_col3, #T_562d9_row7_col2 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row1_col0, #T_562d9_row2_col0 {\n",
|
|||
|
" background-color: #a0da39;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row3_col0 {\n",
|
|||
|
" background-color: #98d83e;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row3_col1 {\n",
|
|||
|
" background-color: #23888e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row3_col2, #T_562d9_row7_col3 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row3_col3 {\n",
|
|||
|
" background-color: #d35171;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row4_col0 {\n",
|
|||
|
" background-color: #48c16e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row4_col1, #T_562d9_row5_col1, #T_562d9_row6_col1 {\n",
|
|||
|
" background-color: #20928c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row4_col2 {\n",
|
|||
|
" background-color: #6c00a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row4_col3, #T_562d9_row5_col3 {\n",
|
|||
|
" background-color: #ca457a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row5_col0 {\n",
|
|||
|
" background-color: #4ac16d;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row5_col2 {\n",
|
|||
|
" background-color: #6e00a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row6_col0, #T_562d9_row7_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row6_col2 {\n",
|
|||
|
" background-color: #6001a6;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_562d9_row6_col3 {\n",
|
|||
|
" background-color: #c9447a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_562d9\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_562d9_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
|
|||
|
" <th id=\"T_562d9_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
|
|||
|
" <th id=\"T_562d9_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
|
|||
|
" <th id=\"T_562d9_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_562d9_level0_row0\" class=\"row_heading level0 row0\" >mlp</th>\n",
|
|||
|
" <td id=\"T_562d9_row0_col0\" class=\"data row0 col0\" >1.401571</td>\n",
|
|||
|
" <td id=\"T_562d9_row0_col1\" class=\"data row0 col1\" >1.401556</td>\n",
|
|||
|
" <td id=\"T_562d9_row0_col2\" class=\"data row0 col2\" >1.001832</td>\n",
|
|||
|
" <td id=\"T_562d9_row0_col3\" class=\"data row0 col3\" >0.049273</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_562d9_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
|
|||
|
" <td id=\"T_562d9_row1_col0\" class=\"data row1 col0\" >1.403185</td>\n",
|
|||
|
" <td id=\"T_562d9_row1_col1\" class=\"data row1 col1\" >1.401859</td>\n",
|
|||
|
" <td id=\"T_562d9_row1_col2\" class=\"data row1 col2\" >1.001885</td>\n",
|
|||
|
" <td id=\"T_562d9_row1_col3\" class=\"data row1 col3\" >0.048861</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_562d9_level0_row2\" class=\"row_heading level0 row2\" >linear</th>\n",
|
|||
|
" <td id=\"T_562d9_row2_col0\" class=\"data row2 col0\" >1.403184</td>\n",
|
|||
|
" <td id=\"T_562d9_row2_col1\" class=\"data row2 col1\" >1.401860</td>\n",
|
|||
|
" <td id=\"T_562d9_row2_col2\" class=\"data row2 col2\" >1.001898</td>\n",
|
|||
|
" <td id=\"T_562d9_row2_col3\" class=\"data row2 col3\" >0.048860</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_562d9_level0_row3\" class=\"row_heading level0 row3\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_562d9_row3_col0\" class=\"data row3 col0\" >1.400360</td>\n",
|
|||
|
" <td id=\"T_562d9_row3_col1\" class=\"data row3 col1\" >1.408185</td>\n",
|
|||
|
" <td id=\"T_562d9_row3_col2\" class=\"data row3 col2\" >1.001482</td>\n",
|
|||
|
" <td id=\"T_562d9_row3_col3\" class=\"data row3 col3\" >0.040258</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_562d9_level0_row4\" class=\"row_heading level0 row4\" >linear_poly</th>\n",
|
|||
|
" <td id=\"T_562d9_row4_col0\" class=\"data row4 col0\" >1.365912</td>\n",
|
|||
|
" <td id=\"T_562d9_row4_col1\" class=\"data row4 col1\" >1.416653</td>\n",
|
|||
|
" <td id=\"T_562d9_row4_col2\" class=\"data row4 col2\" >1.008370</td>\n",
|
|||
|
" <td id=\"T_562d9_row4_col3\" class=\"data row4 col3\" >0.028680</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_562d9_level0_row5\" class=\"row_heading level0 row5\" >linear_interact</th>\n",
|
|||
|
" <td id=\"T_562d9_row5_col0\" class=\"data row5 col0\" >1.366066</td>\n",
|
|||
|
" <td id=\"T_562d9_row5_col1\" class=\"data row5 col1\" >1.417008</td>\n",
|
|||
|
" <td id=\"T_562d9_row5_col2\" class=\"data row5 col2\" >1.008543</td>\n",
|
|||
|
" <td id=\"T_562d9_row5_col3\" class=\"data row5 col3\" >0.028193</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_562d9_level0_row6\" class=\"row_heading level0 row6\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_562d9_row6_col0\" class=\"data row6 col0\" >1.406026</td>\n",
|
|||
|
" <td id=\"T_562d9_row6_col1\" class=\"data row6 col1\" >1.417750</td>\n",
|
|||
|
" <td id=\"T_562d9_row6_col2\" class=\"data row6 col2\" >1.005576</td>\n",
|
|||
|
" <td id=\"T_562d9_row6_col3\" class=\"data row6 col3\" >0.027175</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_562d9_level0_row7\" class=\"row_heading level0 row7\" >knn</th>\n",
|
|||
|
" <td id=\"T_562d9_row7_col0\" class=\"data row7 col0\" >1.296492</td>\n",
|
|||
|
" <td id=\"T_562d9_row7_col1\" class=\"data row7 col1\" >1.493316</td>\n",
|
|||
|
" <td id=\"T_562d9_row7_col2\" class=\"data row7 col2\" >1.041495</td>\n",
|
|||
|
" <td id=\"T_562d9_row7_col3\" class=\"data row7 col3\" >-0.079292</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x13672dd7fe0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 89,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
|
|||
|
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
|
|||
|
"]\n",
|
|||
|
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
|
|||
|
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
|
|||
|
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Выводим лучшую модель"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 90,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'mlp'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
|
|||
|
"\n",
|
|||
|
"display(best_model)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Подбираем гиперпараметры методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 92,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Лучшие параметры: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 100}\n",
|
|||
|
"Лучший результат (MSE): 1.9866610870680514\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"X = df[get_filtered_columns(df, no_numeric=True)]\n",
|
|||
|
"y = df[TARGET_COLUMN_NAME_REGRESSION] \n",
|
|||
|
"\n",
|
|||
|
"model = RandomForestRegressor() \n",
|
|||
|
"\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" 'n_estimators': [50, 100], \n",
|
|||
|
" 'max_depth': [10, 20], \n",
|
|||
|
" 'min_samples_split': [5, 10] \n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
|
|||
|
"\n",
|
|||
|
"grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
|
|||
|
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучаем модель с новыми гиперпараметрами и сравниваем новых данных со старыми"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 93,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Старые параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 100}\n",
|
|||
|
"Лучший результат (MSE) на старых параметрах: 1.9867639342405718\n",
|
|||
|
"\n",
|
|||
|
"Новые параметры: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 100}\n",
|
|||
|
"Лучший результат (MSE) на новых параметрах: 1.990467882679972\n",
|
|||
|
"Среднеквадратическая ошибка (MSE) на тестовых данных: 1.975249119855746\n",
|
|||
|
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 1.4054355623278307\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Old data\n",
|
|||
|
"\n",
|
|||
|
"old_param_grid = param_grid\n",
|
|||
|
"old_grid_search = grid_search\n",
|
|||
|
"old_grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"old_best_params = old_grid_search.best_params_\n",
|
|||
|
"old_best_mse = -old_grid_search.best_score_ \n",
|
|||
|
"\n",
|
|||
|
"# New data\n",
|
|||
|
"\n",
|
|||
|
"new_param_grid = {\n",
|
|||
|
" 'n_estimators': [100],\n",
|
|||
|
" 'max_depth': [10],\n",
|
|||
|
" 'min_samples_split': [5]\n",
|
|||
|
" }\n",
|
|||
|
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
|
|||
|
" param_grid=new_param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', cv=2)\n",
|
|||
|
"\n",
|
|||
|
"new_grid_search.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"new_best_params = new_grid_search.best_params_\n",
|
|||
|
"new_best_mse = -new_grid_search.best_score_\n",
|
|||
|
"\n",
|
|||
|
"new_best_model = RandomForestRegressor(**new_best_params)\n",
|
|||
|
"new_best_model.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"old_best_model = RandomForestRegressor(**old_best_params)\n",
|
|||
|
"old_best_model.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"y_new_pred = new_best_model.predict(X_test)\n",
|
|||
|
"y_old_pred = old_best_model.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
"mse = metrics.mean_squared_error(y_test, y_new_pred)\n",
|
|||
|
"rmse = np.sqrt(mse)\n",
|
|||
|
"\n",
|
|||
|
"print(\"Старые параметры:\", old_best_params)\n",
|
|||
|
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
|
|||
|
"print(\"\\nНовые параметры:\", new_best_params)\n",
|
|||
|
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
|
|||
|
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
|
|||
|
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Визуализация данных"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 94,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABRoAAAK9CAYAAABLm9DzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXhTZfr/8U/SlqYtSRDC3kZACm0FR8AFdFB0VGBEbWF0HB0VV1wAHUdFHQVxQ8VlREe+bgMuuEPRwXEbFRRkE1ltKrtpR7YWSUMhpW3y+6O/Zhqa0rRpm4S+X9fVC3LOc85zP0tO0rtnMfh8Pp8AAAAAAAAAIAzGSAcAAAAAAAAAIPaRaAQAAAAAAAAQNhKNAAAAAAAAAMJGohEAAAAAAABA2Eg0AgAAAAAAAAgbiUYAAAAAAAAAYSPRCAAAAAAAACBsJBoBAAAAAAAAhI1EIwAAAAAAAICwkWgEAAAAgBi1YMECrVmzxv96/vz5+vHHHyMXEACgVSPRCABAI23ZskXjxo1Tr169ZDKZZLFYdPrpp+vZZ5/VwYMHIx0eAKAVWL9+vW699VZt2rRJy5Yt04033ii32x3psAAArZTB5/P5Ih0EAACx5uOPP9bFF1+sxMREXXnllerXr58OHTqkxYsXa+7cuRo7dqxeeumlSIcJADjK7dmzR6eddpo2b94sSRo9erTmzp0b4agAAK0ViUYAABpo27ZtOuGEE5SamqqvvvpKXbt2DVi/efNmffzxx7r11lsjFCEAoDUpKyvThg0blJycrMzMzEiHAwBoxbh0GgCABnriiSe0f/9+vfrqq7WSjJLUu3fvgCSjwWDQ+PHjNWfOHPXt21cmk0mDBg3SN998E7Ddzz//rJtvvll9+/ZVUlKSOnTooIsvvljbt28PKDd79mwZDAb/T3Jysvr3769XXnkloNzYsWPVtm3bWvF98MEHMhgMWrhwYcDy5cuXa8SIEbJarUpOTtaZZ56pJUuWBJR54IEHZDAYVFRUFLD8+++/l8Fg0OzZswPq79GjR0C5goICJSUlyWAw1GrXJ598oqFDhyolJUVms1nnn39+SPcZO7w/Dv954IEHasWfn5+vSy65RBaLRR06dNCtt94qj8dTa99vvvmmBg0apKSkJLVv316XXnqpCgoKgsZRV/2H97PH49EDDzygPn36yGQyqWvXrho9erS2bNkiSdq+fXutvnS73Ro0aJB69uypHTt2+Jc/+eSTOu2009ShQwclJSVp0KBB+uCDDwLqKy4u1siRI5WamqrExER17dpVl19+uX7++eeAcqHsq7qd48ePr7V81KhRAeNd3Y4nn3yyVtl+/fpp2LBh/tcLFy6UwWAIWl+1w+fTlClTZDQa9eWXXwaUu+GGG9SmTRutXbu2zn1Vt6Pm3JCk6dOny2AwBMTWFNsfacyr++lIP2PHjpX0v7le873j9Xp1wgknBH3/hfr+HzZsmPr161er7JNPPlmrvh49emjUqFF19kv1WFbv3+FwKCkpSVdeeWVAucWLFysuLk6TJk2qc19S1Xs2KytLbdu2lcVi0eDBgzV//vyAMg2J/8MPP9T555+vbt26KTExUccdd5weeughVVZWBmwbbHyD9b8U2rGroeNx+BxauXKlfz4EizMxMVGDBg1SZmZmg+YxAABNLT7SAQAAEGv+9a9/qVevXjrttNNC3mbRokV69913NXHiRCUmJuqFF17QiBEjtGLFCv8vyCtXrtR3332nSy+9VKmpqdq+fbtmzpypYcOGKS8vT8nJyQH7fOaZZ2Sz2VRSUqJ//vOfuv7669WjRw+dc845DW7TV199pZEjR2rQoEH+BM6sWbN09tln69tvv9Upp5zS4H0GM3ny5KAJvTfeeENXXXWVhg8frscff1wHDhzQzJkz9dvf/larV6+ulbAM5sEHH1TPnj39r/fv36+bbropaNlLLrlEPXr00LRp07Rs2TLNmDFDv/76q15//XV/mUceeUT333+/LrnkEl133XXas2ePnnvuOZ1xxhlavXq12rVrV2u/5557rj+hsnLlSs2YMSNgfWVlpUaNGqUvv/xSl156qW699Va53W598cUX2rBhg4477rha+ywvL9eYMWPkdDq1ZMmSgOT2s88+qwsvvFCXX365Dh06pHfeeUcXX3yxFixYoPPPP1+SdOjQIZnNZt16663q0KGDtmzZoueee07r1q3T+vXrG7SvaHLffffpX//6l6699lqtX79eZrNZn332mV5++WU99NBD+s1vftOg/e3bt0/Tpk1rdDx1bV/fmJ9zzjl64403/OXnzZun3NzcgGXB5kW1N954I2Aco01mZqYeeugh3XnnnfrDH/6gCy+8UKWlpRo7dqwyMjL04IMPHnH70tJS5eTkqEePHjp48KBmz56tMWPGaOnSpY06Ls2ePVtt27bV7bffrrZt2+qrr77S5MmTVVJSounTpzd4f01x7ApFfQnZauHOYwAAwuYDAAAhc7lcPkm+iy66KORtJPkk+b7//nv/sp9//tlnMpl8OTk5/mUHDhyote3SpUt9knyvv/66f9msWbN8knzbtm3zL9u4caNPku+JJ57wL7vqqqt8KSkptfb5/vvv+yT5vv76a5/P5/N5vV5fenq6b/jw4T6v1xsQT8+ePX3nnnuuf9mUKVN8knx79uwJ2OfKlSt9knyzZs0KqP/YY4/1v96wYYPPaDT6Ro4cGRC/2+32tWvXznf99dcH7HPnzp0+q9Vaa/nhqvtj5cqVAcv37Nnjk+SbMmVKrfgvvPDCgLI333yzT5Jv7dq1Pp/P59u+fbsvLi7O98gjjwSUW79+vS8+Pr7W8kOHDvkk+caPH+9fdng/+3w+3z//+U+fJN/TTz9dqx3Vfb9t2zZ/X3q9Xt/ll1/uS05O9i1fvrzWNofPmUOHDvn69evnO/vss2uVremJJ57wSfIVFRU1eF+SfLfcckutfZ5//vkB413djunTp9cqe/zxx/vOPPNM/+uvv/7aJ8n3/vvv1xnz4fPJ56sajzZt2viuu+4636+//urr3r2776STTvKVl5fXuZ+a7ag5N+666y5fp06dfIMGDQqILdztQxnzmqrnaDCHv/c9Ho/Pbrf731OHv/9Cef/7fD7fmWee6Tv++ONrlZ0+fXqtY82xxx7rO//884PG5/P9byxr7r+ystL329/+1te5c2dfUVGR75ZbbvHFx8fXes+GYvfu3T5JvieffLJR8Qc7zo4bN86XnJzs83g8/mUGg8E3efLkgHKH939Djl0NHY+ac+jf//63T5JvxIgRteZGuPMYAICmxqXTAAA0QElJiSTJbDY3aLshQ4Zo0KBB/td2u10XXXSRPvvsM/8le0lJSf715eXlKi4uVu/evdWuXTv98MMPtfb566+/qqioSFu3btUzzzyjuLg4nXnmmbXKFRUVBfwc/jTSNWvWaNOmTbrssstUXFzsL1daWqrf/e53+uabb+T1egO22bt3b8A+XS5XvX1wzz33aODAgbr44osDln/xxRfat2+f/vSnPwXsMy4uTqeeeqq+/vrrevfdULfcckvA6wkTJkiS/v3vf0uqOqvM6/XqkksuCYipS5cuSk9PrxVT9VmaJpPpiPXOnTtXNpvNX19Nh18SKUl33nmn5syZo/feey/o2Vs158yvv/4ql8uloUOHBp0vbrdbu3fv1tKlS/X222/r+OOPV/v27Ru1L4/HU2telZeXB23zgQMHapU9/DLVmjEWFRVp3759Qdcfrl+/fpo6dapeeeUVDR8+XEVFRXrttdcUH9+wi3b++9//6rnnntP9998f9PLWcLZv6Jg3xD/+8Q8VFxdrypQpdZap7/1frbKyslbZAwcOBC1bXl6uoqIiFRcXq6Kiot44jUajZs+erf3792vkyJF64YUXdM899+ikk04KqZ3V9W3ZskWPPfaYjEajTj/99EbFX3OeV8+3oUOH6sCBA8rPz/ev69SpkwoLC48YV2OOXaGORzWfz6d77rlHY8aM0amnnnrEsuHOYwAAmgKXTgMA0AAWi0WS6v3l8HDp6em1lvXp00cHDhzQnj171KVLFx0
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1600x800 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"plt.figure(figsize=(16, 8))\n",
|
|||
|
"plt.scatter(range(len(y_test)), y_test, label=\"Истинные значения\", color=\"black\", alpha=0.5)\n",
|
|||
|
"plt.scatter(range(len(y_test)), y_new_pred, label=\"Предсказанные (новые параметры)\", color=\"blue\", alpha=0.5)\n",
|
|||
|
"plt.scatter(range(len(y_test)), y_old_pred, label=\"Предсказанные (старые параметры)\", color=\"red\", alpha=0.5)\n",
|
|||
|
"plt.xlabel(\"Выборка\")\n",
|
|||
|
"plt.ylabel(\"Значения\")\n",
|
|||
|
"plt.legend()\n",
|
|||
|
"plt.title(\"Сравнение предсказанных и истинных значений\")\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": ".venv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.12.6"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|