AIM-PIbd-31-Potapov-N-S/lab_4/lab4.ipynb

5161 lines
594 KiB
Plaintext
Raw Normal View History

2024-12-21 04:56:36 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Лабораторная 4\n",
"Датасет: Набор данных для анализа и прогнозирования сердечного приступа"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['State', 'Sex', 'GeneralHealth', 'PhysicalHealthDays',\n",
" 'MentalHealthDays', 'LastCheckupTime', 'PhysicalActivities',\n",
" 'SleepHours', 'RemovedTeeth', 'HadHeartAttack', 'HadAngina',\n",
" 'HadStroke', 'HadAsthma', 'HadSkinCancer', 'HadCOPD',\n",
" 'HadDepressiveDisorder', 'HadKidneyDisease', 'HadArthritis',\n",
" 'HadDiabetes', 'DeafOrHardOfHearing', 'BlindOrVisionDifficulty',\n",
" 'DifficultyConcentrating', 'DifficultyWalking',\n",
" 'DifficultyDressingBathing', 'DifficultyErrands', 'SmokerStatus',\n",
" 'ECigaretteUsage', 'ChestScan', 'RaceEthnicityCategory', 'AgeCategory',\n",
" 'HeightInMeters', 'WeightInKilograms', 'BMI', 'AlcoholDrinkers',\n",
" 'HIVTesting', 'FluVaxLast12', 'PneumoVaxEver', 'TetanusLast10Tdap',\n",
" 'HighRiskLastYear', 'CovidPos'],\n",
" dtype='object')\n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\"csv\\\\heart_2022_no_nans.csv\")\n",
"print(df.columns)\n",
"map_heart_disease_to_int = {'No': 0, 'Yes': 1}\n",
"\n",
"TARGET_COLUMN_NAME_CLASSIFICATION = 'HadHeartAttack'\n",
"\n",
"df[TARGET_COLUMN_NAME_CLASSIFICATION] = df[TARGET_COLUMN_NAME_CLASSIFICATION].map(map_heart_disease_to_int).astype('int32')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Классификация"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Бизнес цель 1: \n",
"Предсказание сердечного приступа (HadHeartAttack) на основе других факторов."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем выборки"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>State</th>\n",
" <th>Sex</th>\n",
" <th>GeneralHealth</th>\n",
" <th>PhysicalHealthDays</th>\n",
" <th>MentalHealthDays</th>\n",
" <th>LastCheckupTime</th>\n",
" <th>PhysicalActivities</th>\n",
" <th>SleepHours</th>\n",
" <th>RemovedTeeth</th>\n",
" <th>HadHeartAttack</th>\n",
" <th>...</th>\n",
" <th>HeightInMeters</th>\n",
" <th>WeightInKilograms</th>\n",
" <th>BMI</th>\n",
" <th>AlcoholDrinkers</th>\n",
" <th>HIVTesting</th>\n",
" <th>FluVaxLast12</th>\n",
" <th>PneumoVaxEver</th>\n",
" <th>TetanusLast10Tdap</th>\n",
" <th>HighRiskLastYear</th>\n",
" <th>CovidPos</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6432</th>\n",
" <td>Arizona</td>\n",
" <td>Male</td>\n",
" <td>Very good</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>Within past 5 years (2 years but less than 5 y...</td>\n",
" <td>Yes</td>\n",
" <td>8.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.88</td>\n",
" <td>77.11</td>\n",
" <td>21.83</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61767</th>\n",
" <td>Indiana</td>\n",
" <td>Female</td>\n",
" <td>Very good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>6.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.73</td>\n",
" <td>77.11</td>\n",
" <td>25.85</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>102005</th>\n",
" <td>Michigan</td>\n",
" <td>Male</td>\n",
" <td>Very good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>7.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.85</td>\n",
" <td>83.46</td>\n",
" <td>24.28</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>183791</th>\n",
" <td>South Dakota</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>10.0</td>\n",
" <td>5.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>7.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.75</td>\n",
" <td>81.65</td>\n",
" <td>26.58</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>230656</th>\n",
" <td>West Virginia</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>No</td>\n",
" <td>8.0</td>\n",
" <td>6 or more, but not all</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.55</td>\n",
" <td>68.04</td>\n",
" <td>28.34</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>93877</th>\n",
" <td>Maryland</td>\n",
" <td>Female</td>\n",
" <td>Very good</td>\n",
" <td>0.0</td>\n",
" <td>12.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>No</td>\n",
" <td>6.0</td>\n",
" <td>6 or more, but not all</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.65</td>\n",
" <td>113.40</td>\n",
" <td>41.60</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>117856</th>\n",
" <td>Missouri</td>\n",
" <td>Male</td>\n",
" <td>Good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>8.0</td>\n",
" <td>1 to 5</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.80</td>\n",
" <td>117.93</td>\n",
" <td>36.26</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>41922</th>\n",
" <td>Georgia</td>\n",
" <td>Male</td>\n",
" <td>Very good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>7.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.78</td>\n",
" <td>113.40</td>\n",
" <td>35.87</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98221</th>\n",
" <td>Massachusetts</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>5.0</td>\n",
" <td>20.0</td>\n",
" <td>Within past 2 years (1 year but less than 2 ye...</td>\n",
" <td>No</td>\n",
" <td>5.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.70</td>\n",
" <td>90.72</td>\n",
" <td>31.32</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>151717</th>\n",
" <td>New York</td>\n",
" <td>Male</td>\n",
" <td>Very good</td>\n",
" <td>2.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>7.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.73</td>\n",
" <td>68.95</td>\n",
" <td>23.11</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>196817 rows × 40 columns</p>\n",
"</div>"
],
"text/plain": [
" State Sex GeneralHealth PhysicalHealthDays \\\n",
"6432 Arizona Male Very good 0.0 \n",
"61767 Indiana Female Very good 0.0 \n",
"102005 Michigan Male Very good 0.0 \n",
"183791 South Dakota Female Good 10.0 \n",
"230656 West Virginia Female Good 0.0 \n",
"... ... ... ... ... \n",
"93877 Maryland Female Very good 0.0 \n",
"117856 Missouri Male Good 0.0 \n",
"41922 Georgia Male Very good 0.0 \n",
"98221 Massachusetts Female Good 5.0 \n",
"151717 New York Male Very good 2.0 \n",
"\n",
" MentalHealthDays LastCheckupTime \\\n",
"6432 5.0 Within past 5 years (2 years but less than 5 y... \n",
"61767 0.0 Within past year (anytime less than 12 months ... \n",
"102005 0.0 Within past year (anytime less than 12 months ... \n",
"183791 5.0 Within past year (anytime less than 12 months ... \n",
"230656 0.0 Within past year (anytime less than 12 months ... \n",
"... ... ... \n",
"93877 12.0 Within past year (anytime less than 12 months ... \n",
"117856 0.0 Within past year (anytime less than 12 months ... \n",
"41922 0.0 Within past year (anytime less than 12 months ... \n",
"98221 20.0 Within past 2 years (1 year but less than 2 ye... \n",
"151717 0.0 Within past year (anytime less than 12 months ... \n",
"\n",
" PhysicalActivities SleepHours RemovedTeeth HadHeartAttack \\\n",
"6432 Yes 8.0 None of them 0 \n",
"61767 Yes 6.0 None of them 0 \n",
"102005 Yes 7.0 None of them 0 \n",
"183791 Yes 7.0 None of them 0 \n",
"230656 No 8.0 6 or more, but not all 0 \n",
"... ... ... ... ... \n",
"93877 No 6.0 6 or more, but not all 0 \n",
"117856 Yes 8.0 1 to 5 0 \n",
"41922 Yes 7.0 None of them 0 \n",
"98221 No 5.0 None of them 0 \n",
"151717 Yes 7.0 None of them 0 \n",
"\n",
" ... HeightInMeters WeightInKilograms BMI AlcoholDrinkers \\\n",
"6432 ... 1.88 77.11 21.83 Yes \n",
"61767 ... 1.73 77.11 25.85 Yes \n",
"102005 ... 1.85 83.46 24.28 Yes \n",
"183791 ... 1.75 81.65 26.58 No \n",
"230656 ... 1.55 68.04 28.34 No \n",
"... ... ... ... ... ... \n",
"93877 ... 1.65 113.40 41.60 No \n",
"117856 ... 1.80 117.93 36.26 No \n",
"41922 ... 1.78 113.40 35.87 Yes \n",
"98221 ... 1.70 90.72 31.32 Yes \n",
"151717 ... 1.73 68.95 23.11 Yes \n",
"\n",
" HIVTesting FluVaxLast12 PneumoVaxEver \\\n",
"6432 Yes Yes No \n",
"61767 Yes No No \n",
"102005 No No No \n",
"183791 Yes No No \n",
"230656 No No No \n",
"... ... ... ... \n",
"93877 No No No \n",
"117856 No No No \n",
"41922 No Yes No \n",
"98221 Yes No Yes \n",
"151717 Yes Yes No \n",
"\n",
" TetanusLast10Tdap HighRiskLastYear \\\n",
"6432 No, did not receive any tetanus shot in the pa... No \n",
"61767 Yes, received Tdap No \n",
"102005 Yes, received Tdap No \n",
"183791 Yes, received tetanus shot but not sure what type No \n",
"230656 No, did not receive any tetanus shot in the pa... No \n",
"... ... ... \n",
"93877 Yes, received Tdap No \n",
"117856 Yes, received tetanus shot but not sure what type No \n",
"41922 Yes, received tetanus shot but not sure what type No \n",
"98221 Yes, received Tdap No \n",
"151717 Yes, received Tdap No \n",
"\n",
" CovidPos \n",
"6432 Yes \n",
"61767 No \n",
"102005 Yes \n",
"183791 No \n",
"230656 No \n",
"... ... \n",
"93877 Yes \n",
"117856 Yes \n",
"41922 No \n",
"98221 No \n",
"151717 Yes \n",
"\n",
"[196817 rows x 40 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HadHeartAttack</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6432</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61767</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>102005</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>183791</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>230656</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>93877</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>117856</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>41922</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98221</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>151717</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>196817 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" HadHeartAttack\n",
"6432 0\n",
"61767 0\n",
"102005 0\n",
"183791 0\n",
"230656 0\n",
"... ...\n",
"93877 0\n",
"117856 0\n",
"41922 0\n",
"98221 0\n",
"151717 0\n",
"\n",
"[196817 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>State</th>\n",
" <th>Sex</th>\n",
" <th>GeneralHealth</th>\n",
" <th>PhysicalHealthDays</th>\n",
" <th>MentalHealthDays</th>\n",
" <th>LastCheckupTime</th>\n",
" <th>PhysicalActivities</th>\n",
" <th>SleepHours</th>\n",
" <th>RemovedTeeth</th>\n",
" <th>HadHeartAttack</th>\n",
" <th>...</th>\n",
" <th>HeightInMeters</th>\n",
" <th>WeightInKilograms</th>\n",
" <th>BMI</th>\n",
" <th>AlcoholDrinkers</th>\n",
" <th>HIVTesting</th>\n",
" <th>FluVaxLast12</th>\n",
" <th>PneumoVaxEver</th>\n",
" <th>TetanusLast10Tdap</th>\n",
" <th>HighRiskLastYear</th>\n",
" <th>CovidPos</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>108080</th>\n",
" <td>Minnesota</td>\n",
" <td>Female</td>\n",
" <td>Very good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>7.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.68</td>\n",
" <td>81.65</td>\n",
" <td>29.05</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes, received tetanus shot, but not Tdap</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109629</th>\n",
" <td>Minnesota</td>\n",
" <td>Female</td>\n",
" <td>Very good</td>\n",
" <td>1.0</td>\n",
" <td>15.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>6.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.68</td>\n",
" <td>99.79</td>\n",
" <td>35.51</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24640</th>\n",
" <td>Connecticut</td>\n",
" <td>Male</td>\n",
" <td>Good</td>\n",
" <td>15.0</td>\n",
" <td>5.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>7.0</td>\n",
" <td>6 or more, but not all</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.70</td>\n",
" <td>72.57</td>\n",
" <td>25.06</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12715</th>\n",
" <td>Arkansas</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>8.0</td>\n",
" <td>30.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>7.0</td>\n",
" <td>1 to 5</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.63</td>\n",
" <td>86.18</td>\n",
" <td>32.61</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>162549</th>\n",
" <td>Ohio</td>\n",
" <td>Female</td>\n",
" <td>Excellent</td>\n",
" <td>0.0</td>\n",
" <td>7.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>4.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.60</td>\n",
" <td>81.19</td>\n",
" <td>31.71</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>Yes</td>\n",
" <td>Tested positive using home test without a heal...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>187130</th>\n",
" <td>South Dakota</td>\n",
" <td>Male</td>\n",
" <td>Poor</td>\n",
" <td>30.0</td>\n",
" <td>30.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>No</td>\n",
" <td>4.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.83</td>\n",
" <td>97.98</td>\n",
" <td>29.29</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38512</th>\n",
" <td>Florida</td>\n",
" <td>Male</td>\n",
" <td>Excellent</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past 5 years (2 years but less than 5 y...</td>\n",
" <td>Yes</td>\n",
" <td>8.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.83</td>\n",
" <td>104.33</td>\n",
" <td>31.19</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>125776</th>\n",
" <td>Nebraska</td>\n",
" <td>Male</td>\n",
" <td>Fair</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>No</td>\n",
" <td>6.0</td>\n",
" <td>1 to 5</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.73</td>\n",
" <td>92.99</td>\n",
" <td>31.17</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33614</th>\n",
" <td>Florida</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>7.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.60</td>\n",
" <td>65.77</td>\n",
" <td>25.69</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>223067</th>\n",
" <td>Washington</td>\n",
" <td>Male</td>\n",
" <td>Excellent</td>\n",
" <td>0.0</td>\n",
" <td>2.0</td>\n",
" <td>Within past 2 years (1 year but less than 2 ye...</td>\n",
" <td>Yes</td>\n",
" <td>7.0</td>\n",
" <td>1 to 5</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.75</td>\n",
" <td>70.00</td>\n",
" <td>22.86</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>49205 rows × 40 columns</p>\n",
"</div>"
],
"text/plain": [
" State Sex GeneralHealth PhysicalHealthDays \\\n",
"108080 Minnesota Female Very good 0.0 \n",
"109629 Minnesota Female Very good 1.0 \n",
"24640 Connecticut Male Good 15.0 \n",
"12715 Arkansas Female Good 8.0 \n",
"162549 Ohio Female Excellent 0.0 \n",
"... ... ... ... ... \n",
"187130 South Dakota Male Poor 30.0 \n",
"38512 Florida Male Excellent 0.0 \n",
"125776 Nebraska Male Fair 1.0 \n",
"33614 Florida Female Good 0.0 \n",
"223067 Washington Male Excellent 0.0 \n",
"\n",
" MentalHealthDays LastCheckupTime \\\n",
"108080 0.0 Within past year (anytime less than 12 months ... \n",
"109629 15.0 Within past year (anytime less than 12 months ... \n",
"24640 5.0 Within past year (anytime less than 12 months ... \n",
"12715 30.0 Within past year (anytime less than 12 months ... \n",
"162549 7.0 Within past year (anytime less than 12 months ... \n",
"... ... ... \n",
"187130 30.0 Within past year (anytime less than 12 months ... \n",
"38512 0.0 Within past 5 years (2 years but less than 5 y... \n",
"125776 2.0 Within past year (anytime less than 12 months ... \n",
"33614 0.0 Within past year (anytime less than 12 months ... \n",
"223067 2.0 Within past 2 years (1 year but less than 2 ye... \n",
"\n",
" PhysicalActivities SleepHours RemovedTeeth HadHeartAttack \\\n",
"108080 Yes 7.0 None of them 0 \n",
"109629 Yes 6.0 None of them 0 \n",
"24640 Yes 7.0 6 or more, but not all 0 \n",
"12715 Yes 7.0 1 to 5 0 \n",
"162549 Yes 4.0 None of them 0 \n",
"... ... ... ... ... \n",
"187130 No 4.0 None of them 0 \n",
"38512 Yes 8.0 None of them 0 \n",
"125776 No 6.0 1 to 5 0 \n",
"33614 Yes 7.0 None of them 0 \n",
"223067 Yes 7.0 1 to 5 0 \n",
"\n",
" ... HeightInMeters WeightInKilograms BMI AlcoholDrinkers \\\n",
"108080 ... 1.68 81.65 29.05 Yes \n",
"109629 ... 1.68 99.79 35.51 Yes \n",
"24640 ... 1.70 72.57 25.06 Yes \n",
"12715 ... 1.63 86.18 32.61 Yes \n",
"162549 ... 1.60 81.19 31.71 Yes \n",
"... ... ... ... ... ... \n",
"187130 ... 1.83 97.98 29.29 Yes \n",
"38512 ... 1.83 104.33 31.19 Yes \n",
"125776 ... 1.73 92.99 31.17 No \n",
"33614 ... 1.60 65.77 25.69 Yes \n",
"223067 ... 1.75 70.00 22.86 Yes \n",
"\n",
" HIVTesting FluVaxLast12 PneumoVaxEver \\\n",
"108080 No No No \n",
"109629 No Yes No \n",
"24640 Yes Yes Yes \n",
"12715 Yes Yes No \n",
"162549 Yes Yes No \n",
"... ... ... ... \n",
"187130 No No No \n",
"38512 No No No \n",
"125776 Yes No Yes \n",
"33614 No No Yes \n",
"223067 Yes Yes No \n",
"\n",
" TetanusLast10Tdap HighRiskLastYear \\\n",
"108080 Yes, received tetanus shot, but not Tdap No \n",
"109629 Yes, received tetanus shot but not sure what type No \n",
"24640 Yes, received tetanus shot but not sure what type No \n",
"12715 Yes, received Tdap No \n",
"162549 Yes, received Tdap Yes \n",
"... ... ... \n",
"187130 No, did not receive any tetanus shot in the pa... No \n",
"38512 Yes, received tetanus shot but not sure what type No \n",
"125776 Yes, received tetanus shot but not sure what type No \n",
"33614 Yes, received Tdap No \n",
"223067 Yes, received Tdap No \n",
"\n",
" CovidPos \n",
"108080 Yes \n",
"109629 No \n",
"24640 No \n",
"12715 No \n",
"162549 Tested positive using home test without a heal... \n",
"... ... \n",
"187130 No \n",
"38512 No \n",
"125776 Yes \n",
"33614 No \n",
"223067 No \n",
"\n",
"[49205 rows x 40 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HadHeartAttack</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>108080</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109629</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24640</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12715</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>162549</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>187130</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38512</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>125776</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33614</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>223067</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>49205 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" HadHeartAttack\n",
"108080 0\n",
"109629 0\n",
"24640 0\n",
"12715 0\n",
"162549 0\n",
"... ...\n",
"187130 0\n",
"38512 0\n",
"125776 0\n",
"33614 0\n",
"223067 0\n",
"\n",
"[49205 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=TARGET_COLUMN_NAME_CLASSIFICATION, frac_train=0.80, frac_val=0, frac_test=0.20, random_state=9\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Пропущенные значения по столбцам:\n",
"State 0\n",
"Sex 0\n",
"GeneralHealth 0\n",
"PhysicalHealthDays 0\n",
"MentalHealthDays 0\n",
"LastCheckupTime 0\n",
"PhysicalActivities 0\n",
"SleepHours 0\n",
"RemovedTeeth 0\n",
"HadHeartAttack 0\n",
"HadAngina 0\n",
"HadStroke 0\n",
"HadAsthma 0\n",
"HadSkinCancer 0\n",
"HadCOPD 0\n",
"HadDepressiveDisorder 0\n",
"HadKidneyDisease 0\n",
"HadArthritis 0\n",
"HadDiabetes 0\n",
"DeafOrHardOfHearing 0\n",
"BlindOrVisionDifficulty 0\n",
"DifficultyConcentrating 0\n",
"DifficultyWalking 0\n",
"DifficultyDressingBathing 0\n",
"DifficultyErrands 0\n",
"SmokerStatus 0\n",
"ECigaretteUsage 0\n",
"ChestScan 0\n",
"RaceEthnicityCategory 0\n",
"AgeCategory 0\n",
"HeightInMeters 0\n",
"WeightInKilograms 0\n",
"BMI 0\n",
"AlcoholDrinkers 0\n",
"HIVTesting 0\n",
"FluVaxLast12 0\n",
"PneumoVaxEver 0\n",
"TetanusLast10Tdap 0\n",
"HighRiskLastYear 0\n",
"CovidPos 0\n",
"dtype: int64\n",
"\n",
"Статистический обзор данных:\n",
" PhysicalHealthDays MentalHealthDays SleepHours HadHeartAttack \\\n",
"count 246022.000000 246022.000000 246022.000000 246022.000000 \n",
"mean 4.119026 4.167140 7.021331 0.054609 \n",
"std 8.405844 8.102687 1.440681 0.227216 \n",
"min 0.000000 0.000000 1.000000 0.000000 \n",
"25% 0.000000 0.000000 6.000000 0.000000 \n",
"50% 0.000000 0.000000 7.000000 0.000000 \n",
"75% 3.000000 4.000000 8.000000 0.000000 \n",
"max 30.000000 30.000000 24.000000 1.000000 \n",
"\n",
" HeightInMeters WeightInKilograms BMI \n",
"count 246022.000000 246022.000000 246022.000000 \n",
"mean 1.705150 83.615179 28.668136 \n",
"std 0.106654 21.323156 6.513973 \n",
"min 0.910000 28.120000 12.020000 \n",
"25% 1.630000 68.040000 24.270000 \n",
"50% 1.700000 81.650000 27.460000 \n",
"75% 1.780000 95.250000 31.890000 \n",
"max 2.410000 292.570000 97.650000 \n"
]
}
],
"source": [
"null_values = df.isnull().sum()\n",
"print(\"Пропущенные значения по столбцам:\")\n",
"print(null_values)\n",
"\n",
"stat_summary = df.describe()\n",
"print(\"\\nСтатистический обзор данных:\")\n",
"print(stat_summary)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем конвеер для классификации данных и проверка конвеера"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PhysicalHealthDays</th>\n",
" <th>MentalHealthDays</th>\n",
" <th>SleepHours</th>\n",
" <th>HadHeartAttack</th>\n",
" <th>HeightInMeters</th>\n",
" <th>WeightInKilograms</th>\n",
" <th>BMI</th>\n",
" <th>State_Alaska</th>\n",
" <th>State_Arizona</th>\n",
" <th>State_Arkansas</th>\n",
" <th>...</th>\n",
" <th>AlcoholDrinkers_Yes</th>\n",
" <th>HIVTesting_Yes</th>\n",
" <th>FluVaxLast12_Yes</th>\n",
" <th>PneumoVaxEver_Yes</th>\n",
" <th>TetanusLast10Tdap_Yes, received Tdap</th>\n",
" <th>TetanusLast10Tdap_Yes, received tetanus shot but not sure what type</th>\n",
" <th>TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap</th>\n",
" <th>HighRiskLastYear_Yes</th>\n",
" <th>CovidPos_Tested positive using home test without a health professional</th>\n",
" <th>CovidPos_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6432</th>\n",
" <td>-0.490179</td>\n",
" <td>0.103124</td>\n",
" <td>0.677965</td>\n",
" <td>-0.24034</td>\n",
" <td>1.639362</td>\n",
" <td>-0.304540</td>\n",
" <td>-1.051314</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61767</th>\n",
" <td>-0.490179</td>\n",
" <td>-0.513985</td>\n",
" <td>-0.708460</td>\n",
" <td>-0.24034</td>\n",
" <td>0.233664</td>\n",
" <td>-0.304540</td>\n",
" <td>-0.432966</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>102005</th>\n",
" <td>-0.490179</td>\n",
" <td>-0.513985</td>\n",
" <td>-0.015247</td>\n",
" <td>-0.24034</td>\n",
" <td>1.358222</td>\n",
" <td>-0.006656</td>\n",
" <td>-0.674460</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>183791</th>\n",
" <td>0.699048</td>\n",
" <td>0.103124</td>\n",
" <td>-0.015247</td>\n",
" <td>-0.24034</td>\n",
" <td>0.421091</td>\n",
" <td>-0.091564</td>\n",
" <td>-0.320678</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>230656</th>\n",
" <td>-0.490179</td>\n",
" <td>-0.513985</td>\n",
" <td>0.677965</td>\n",
" <td>-0.24034</td>\n",
" <td>-1.453173</td>\n",
" <td>-0.730021</td>\n",
" <td>-0.049959</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>93877</th>\n",
" <td>-0.490179</td>\n",
" <td>0.967076</td>\n",
" <td>-0.708460</td>\n",
" <td>-0.24034</td>\n",
" <td>-0.516041</td>\n",
" <td>1.397856</td>\n",
" <td>1.989666</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>117856</th>\n",
" <td>-0.490179</td>\n",
" <td>-0.513985</td>\n",
" <td>0.677965</td>\n",
" <td>-0.24034</td>\n",
" <td>0.889656</td>\n",
" <td>1.610362</td>\n",
" <td>1.168279</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>41922</th>\n",
" <td>-0.490179</td>\n",
" <td>-0.513985</td>\n",
" <td>-0.015247</td>\n",
" <td>-0.24034</td>\n",
" <td>0.702230</td>\n",
" <td>1.397856</td>\n",
" <td>1.108290</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98221</th>\n",
" <td>0.104435</td>\n",
" <td>1.954450</td>\n",
" <td>-1.401672</td>\n",
" <td>-0.24034</td>\n",
" <td>-0.047475</td>\n",
" <td>0.333917</td>\n",
" <td>0.408418</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>151717</th>\n",
" <td>-0.252334</td>\n",
" <td>-0.513985</td>\n",
" <td>-0.015247</td>\n",
" <td>-0.24034</td>\n",
" <td>0.233664</td>\n",
" <td>-0.687332</td>\n",
" <td>-0.854427</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>196817 rows × 109 columns</p>\n",
"</div>"
],
"text/plain": [
" PhysicalHealthDays MentalHealthDays SleepHours HadHeartAttack \\\n",
"6432 -0.490179 0.103124 0.677965 -0.24034 \n",
"61767 -0.490179 -0.513985 -0.708460 -0.24034 \n",
"102005 -0.490179 -0.513985 -0.015247 -0.24034 \n",
"183791 0.699048 0.103124 -0.015247 -0.24034 \n",
"230656 -0.490179 -0.513985 0.677965 -0.24034 \n",
"... ... ... ... ... \n",
"93877 -0.490179 0.967076 -0.708460 -0.24034 \n",
"117856 -0.490179 -0.513985 0.677965 -0.24034 \n",
"41922 -0.490179 -0.513985 -0.015247 -0.24034 \n",
"98221 0.104435 1.954450 -1.401672 -0.24034 \n",
"151717 -0.252334 -0.513985 -0.015247 -0.24034 \n",
"\n",
" HeightInMeters WeightInKilograms BMI State_Alaska \\\n",
"6432 1.639362 -0.304540 -1.051314 0.0 \n",
"61767 0.233664 -0.304540 -0.432966 0.0 \n",
"102005 1.358222 -0.006656 -0.674460 0.0 \n",
"183791 0.421091 -0.091564 -0.320678 0.0 \n",
"230656 -1.453173 -0.730021 -0.049959 0.0 \n",
"... ... ... ... ... \n",
"93877 -0.516041 1.397856 1.989666 0.0 \n",
"117856 0.889656 1.610362 1.168279 0.0 \n",
"41922 0.702230 1.397856 1.108290 0.0 \n",
"98221 -0.047475 0.333917 0.408418 0.0 \n",
"151717 0.233664 -0.687332 -0.854427 0.0 \n",
"\n",
" State_Arizona State_Arkansas ... AlcoholDrinkers_Yes \\\n",
"6432 1.0 0.0 ... 1.0 \n",
"61767 0.0 0.0 ... 1.0 \n",
"102005 0.0 0.0 ... 1.0 \n",
"183791 0.0 0.0 ... 0.0 \n",
"230656 0.0 0.0 ... 0.0 \n",
"... ... ... ... ... \n",
"93877 0.0 0.0 ... 0.0 \n",
"117856 0.0 0.0 ... 0.0 \n",
"41922 0.0 0.0 ... 1.0 \n",
"98221 0.0 0.0 ... 1.0 \n",
"151717 0.0 0.0 ... 1.0 \n",
"\n",
" HIVTesting_Yes FluVaxLast12_Yes PneumoVaxEver_Yes \\\n",
"6432 1.0 1.0 0.0 \n",
"61767 1.0 0.0 0.0 \n",
"102005 0.0 0.0 0.0 \n",
"183791 1.0 0.0 0.0 \n",
"230656 0.0 0.0 0.0 \n",
"... ... ... ... \n",
"93877 0.0 0.0 0.0 \n",
"117856 0.0 0.0 0.0 \n",
"41922 0.0 1.0 0.0 \n",
"98221 1.0 0.0 1.0 \n",
"151717 1.0 1.0 0.0 \n",
"\n",
" TetanusLast10Tdap_Yes, received Tdap \\\n",
"6432 0.0 \n",
"61767 1.0 \n",
"102005 1.0 \n",
"183791 0.0 \n",
"230656 0.0 \n",
"... ... \n",
"93877 1.0 \n",
"117856 0.0 \n",
"41922 0.0 \n",
"98221 1.0 \n",
"151717 1.0 \n",
"\n",
" TetanusLast10Tdap_Yes, received tetanus shot but not sure what type \\\n",
"6432 0.0 \n",
"61767 0.0 \n",
"102005 0.0 \n",
"183791 1.0 \n",
"230656 0.0 \n",
"... ... \n",
"93877 0.0 \n",
"117856 1.0 \n",
"41922 1.0 \n",
"98221 0.0 \n",
"151717 0.0 \n",
"\n",
" TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap \\\n",
"6432 0.0 \n",
"61767 0.0 \n",
"102005 0.0 \n",
"183791 0.0 \n",
"230656 0.0 \n",
"... ... \n",
"93877 0.0 \n",
"117856 0.0 \n",
"41922 0.0 \n",
"98221 0.0 \n",
"151717 0.0 \n",
"\n",
" HighRiskLastYear_Yes \\\n",
"6432 0.0 \n",
"61767 0.0 \n",
"102005 0.0 \n",
"183791 0.0 \n",
"230656 0.0 \n",
"... ... \n",
"93877 0.0 \n",
"117856 0.0 \n",
"41922 0.0 \n",
"98221 0.0 \n",
"151717 0.0 \n",
"\n",
" CovidPos_Tested positive using home test without a health professional \\\n",
"6432 0.0 \n",
"61767 0.0 \n",
"102005 0.0 \n",
"183791 0.0 \n",
"230656 0.0 \n",
"... ... \n",
"93877 0.0 \n",
"117856 0.0 \n",
"41922 0.0 \n",
"98221 0.0 \n",
"151717 0.0 \n",
"\n",
" CovidPos_Yes \n",
"6432 1.0 \n",
"61767 0.0 \n",
"102005 1.0 \n",
"183791 0.0 \n",
"230656 0.0 \n",
"... ... \n",
"93877 1.0 \n",
"117856 1.0 \n",
"41922 0.0 \n",
"98221 0.0 \n",
"151717 1.0 \n",
"\n",
"[196817 rows x 109 columns]"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"columns_to_drop = ['AgeCategory', 'Sex']\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")\n",
"\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем набор моделей"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=9\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=9,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучаем модели и тестируем их"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict, average=None)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict, average=None)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Матрица неточностей"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7QAAAQ9CAYAAABp3wEwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxN6eMH8M+9t31XqaS0SEgIGdPYiezrfLPLPgzGMtbvWIqxjUGWwQySMUaYsX0xGSQzg2EsGUuyyxISlaLt3vP7w6/DdetUVLfl8369zmum8zznOc895NNzes5zZIIgCCAiIiIiIiIqZeTa7gARERERERHR++CAloiIiIiIiEolDmiJiIiIiIioVOKAloiIiIiIiEolDmiJiIiIiIioVOKAloiIiIiIiEolDmiJiIiIiIioVOKAloiIiIiIiEolDmiJiIiIiIioVOKAlqgMCA0NhUwmw507d4qk/Tt37kAmkyE0NLRQ2ouMjIRMJkNkZGShtEdERFRWBAYGQiaT5auuTCZDYGBg0XaIqITjgJaIiszq1asLbRBMRERERPQuHW13gIhKPicnJ7x69Qq6uroFOm716tWwtrbGoEGD1PY3a9YMr169gp6eXiH2koiIqPSbMWMGpk2bpu1uEJUaHNASUZ5kMhkMDAwKrT25XF6o7REREZUFqampMDY2ho4Of0Qnyi9OOSYqo1avXo1atWpBX18f9vb2GD16NBITEzXqfffdd3B1dYWhoSE++ugj/Pnnn2jRogVatGgh1snpGdpHjx5h8ODBcHBwgL6+PipVqoSuXbuKz/E6Ozvj8uXLOHbsGGQyGWQymdhmbs/Qnjp1Ch06dECFChVgbGyMOnXqYPny5YV7YYiIiEqA7Gdlr1y5gr59+6JChQpo0qRJjs/QpqenY8KECahYsSJMTU3RpUsX3L9/P8d2IyMj4e3tDQMDA1StWhXff/99rs/l/vTTT2jQoAEMDQ1haWmJ3r174969e0XyeYmKCm//EJVBgYGBCAoKgq+vL0aNGoWYmBisWbMG//zzD44fPy5OHV6zZg3GjBmDpk2bYsKECbhz5w66deuGChUqwMHBQfIcPXv2xOXLlzF27Fg4OzvjyZMnOHToEGJjY+Hs7Izg4GCMHTsWJiYm+OqrrwAAtra2ubZ36NAhdOrUCZUqVcK4ceNgZ2eH6Oho7Nu3D+PGjSu8i0NERFSC/Oc//0G1atUwf/58CIKAJ0+eaNQZNmwYfvrpJ/Tt2xeffPIJIiIi0LFjR41658+fR7t27VCpUiUEBQVBqVRizpw5qFixokbdefPmYebMmfD398ewYcMQHx+PlStXolmzZjh//jwsLCyK4uMSFT6BiEq9jRs3CgCE27dvC0+ePBH09PSEtm3bCkqlUqyzatUqAYAQEhIiCIIgpKenC1ZWVkLDhg2FzMxMsV5oaKgAQGjevLm47/bt2wIAYePGjYIgCMLz588FAMLixYsl+1WrVi21drIdPXpUACAcPXpUEARByMrKElxcXAQnJyfh+fPnanVVKlX+LwQREVEpMXv2bAGA0KdPnxz3Z4uKihIACJ9//rlavb59+woAhNmzZ4v7OnfuLBgZGQkPHjwQ912/fl3Q0dFRa/POnTuCQqEQ5s2bp9bmxYsXBR0dHY39RCUZpxwTlTGHDx9GRkYGxo8fD7n8zbf48OHDYWZmhv379wMAzpw5g4SEBAwfPlztWZ1+/fqhQoUKkucwNDSEnp4eIiMj8fz58w/u8/nz53H79m2MHz9e445wfl9dQEREVBqNHDlSsvzAgQMAgC+++EJt//jx49W+ViqVOHz4MLp16wZ7e3txv5ubG9q3b69Wd+fOnVCpVPD398fTp0/Fzc7ODtWqVcPRo0c/4BMRFS9OOSYqY+7evQsAqF69utp+PT09uLq6iuXZ/3Vzc1Orp6OjA2dnZ8lz6OvrY9GiRfjyyy9ha2uLjz/+GJ06dcLAgQNhZ2dX4D7fvHkTAODp6VngY4mIiEozFxcXyfK7d+9CLpejatWqavvfzfknT57g1atXGrkOaGb99evXIQgCqlWrluM5C/pWAyJt4oCWiN7L+PHj0blzZ+zevRsHDx7EzJkzsWDBAkRERKBevXra7h4REVGpYGhoWOznVKlUkMlk+O2336BQKDTKTUxMir1PRO+LU46JyhgnJycAQExMjNr+jIwM3L59WyzP/u+NGzfU6mVlZYkrFeelatWq+PLLL/H777/j0qVLyMjIwJIlS8Ty/E4Xzr7rfOnSpXzVJyIiKi+cnJygUqnE2UzZ3s15GxsbGBgYaOQ6oJn1VatWhSAIcHFxga+vr8b28ccfF/4HISoiHNASlTG+vr7Q09PDihUrIAiCuH/Dhg1ISkoSV0X09vaGlZUV1q1bh6ysLLHeli1b8nwu9uXLl0hLS1PbV7VqVZiamiI9PV3cZ2xsnOOrgt5Vv359uLi4IDg4WKP+25+BiIiovMl+/nXFihVq+4ODg9W+VigU8PX1xe7du/Hw4UNx/40bN/Dbb7+p1e3RowcUCgWCgoI0clYQBCQkJBTiJyAqWpxyTFTGVKxYEdOnT0dQUBDatWuHLl26ICYmBqtXr0bDhg3Rv39/AK+fqQ0MDMTYsWPRqlUr+Pv7486dOwgNDUXVqlUlf7t67do1tG7dGv7+/vDw8ICOjg527dqFx48fo3fv3mK9Bg0aYM2aNfj666/h5uYGGxsbtGrVSqM9uVyONWvWoHPnzvDy8sLgwYNRqVIlXL16FZcvX8bBgwcL/0IRERGVAl5eXujTpw9Wr16NpKQkfPLJJzhy5EiOv4kNDAzE77//jsaNG2PUqFFQKpVYtWoVPD09ERUVJdarWrUqvv76a0yfPl18ZZ+pqSlu376NXbt2YcSIEZg0aVIxfkqi98cBLVEZFBgYiIoVK2LVqlWYMGECLC0tMWLECMyfP19toYcxY8ZAEAQsWbIEkyZNQt26dbF371588cUXMDAwyLV9R0dH9OnTB0eOHMHmzZuho6ODGjVqYPv27ejZs6dYb9asWbh79y6++eYbvHjxAs2bN89xQAsAfn5+OHr0KIKCgrBkyRKoVCpUrVoVw4cPL7wLQ0REVAqFhISgYsWK2LJlC3bv3o1WrVph//79cHR0VKvXoEED/Pbbb5g0aRJmzpwJR0dHzJkzB9HR0bh69apa3WnTpsHd3R3Lli1DUFAQgNf53rZtW3Tp0qXYPhvRh5IJnM9HRG9RqVSoWLEievTogXXr1mm7O0RERPSBunXrhsuXL+P69eva7gpRoeMztETlWFpamsazMz/++COePXuGFi1aaKdTRERE9N5evXql9vX169dx4MAB5jqVWfwNLVE5FhkZiQkTJuA///kPrKyscO7cOWzYsAE1a9bE2bNnoaenp+0uEhERUQFUqlQJgwYNEt89v2bNGqSnp+P8+fO5vneWqDTjM7RE5ZizszMcHR2xYsUKPHv2DJaWlhg4cCAWLlzIwSwREVEp1K5dO2zduhWPHj2Cvr4+fHx8MH/+fA5mqczib2iJiIiIiIioVOIztERERERERFQqcUBLREREREREpRKfoaVSR6VS4eHDhzA1NYVMJtN2d4iKlSAIePHiBezt7SGXF+49ybS0NGRkZORZT09PT/I9xURU/jCbqTxjNmsXB7RU6jx8+FDjReJE5c29e/fg4OBQaO2lpaXBxckEj54o86xrZ2eH27dvl9vgJCJNzGYiZrO2cEBLpY6pqSkA4O45Z5iZcNa8NnR3r63tLpRbWcjEXzggfh8UloyMDDx6osSNM44wM839+yr5hQpu3veQkZFRLkOTiHLGbNY+ZrP2MJu1iwNaKnWypzKZmcglv7mp6OjIdLXdhfLr/9elL6opfSamMpiY5t62CpxKSESamM3ax2zWImazVnFAS0REokxBiUyJt7llCqpi7A0RERExm6VxQEtERCIVBKiQe2hKlREREVHhYzZL44CWiIhEKghQMjSJiIhKDGazNA5oiYhIlCmokCmRi+V9WhMREVFxYzZL44CWiIhEqv/fpMqJiIio+DCbpXFAS0REImUe05qkyoiIiKj
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[f\"No {TARGET_COLUMN_NAME_CLASSIFICATION}\", TARGET_COLUMN_NAME_CLASSIFICATION]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Precision_train</th>\n",
" <th>Precision_test</th>\n",
" <th>Recall_train</th>\n",
" <th>Recall_test</th>\n",
" <th>Accuracy_train</th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_train</th>\n",
" <th>F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>logistic</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ridge</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>decision_tree</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>knn</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.999907</td>\n",
" <td>1.0</td>\n",
" <td>0.999995</td>\n",
" <td>1.0</td>\n",
" <td>[0.9999973128320332, 0.9999534775529193]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>naive_bayes</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gradient_boosting</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>random_forest</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mlp</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Precision_train Precision_test Recall_train Recall_test \\\n",
"logistic 1.0 1.0 1.000000 1.0 \n",
"ridge 1.0 1.0 1.000000 1.0 \n",
"decision_tree 1.0 1.0 1.000000 1.0 \n",
"knn 1.0 1.0 0.999907 1.0 \n",
"naive_bayes 1.0 1.0 1.000000 1.0 \n",
"gradient_boosting 1.0 1.0 1.000000 1.0 \n",
"random_forest 1.0 1.0 1.000000 1.0 \n",
"mlp 1.0 1.0 1.000000 1.0 \n",
"\n",
" Accuracy_train Accuracy_test \\\n",
"logistic 1.000000 1.0 \n",
"ridge 1.000000 1.0 \n",
"decision_tree 1.000000 1.0 \n",
"knn 0.999995 1.0 \n",
"naive_bayes 1.000000 1.0 \n",
"gradient_boosting 1.000000 1.0 \n",
"random_forest 1.000000 1.0 \n",
"mlp 1.000000 1.0 \n",
"\n",
" F1_train F1_test \n",
"logistic [1.0, 1.0] [1.0, 1.0] \n",
"ridge [1.0, 1.0] [1.0, 1.0] \n",
"decision_tree [1.0, 1.0] [1.0, 1.0] \n",
"knn [0.9999973128320332, 0.9999534775529193] [1.0, 1.0] \n",
"naive_bayes [1.0, 1.0] [1.0, 1.0] \n",
"gradient_boosting [1.0, 1.0] [1.0, 1.0] \n",
"random_forest [1.0, 1.0] [1.0, 1.0] \n",
"mlp [1.0, 1.0] [1.0, 1.0] "
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_test</th>\n",
" <th>ROC_AUC_test</th>\n",
" <th>Cohen_kappa_test</th>\n",
" <th>MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>logistic</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ridge</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>decision_tree</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>knn</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>naive_bayes</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gradient_boosting</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>random_forest</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mlp</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Accuracy_test F1_test ROC_AUC_test Cohen_kappa_test \\\n",
"logistic 1.0 [1.0, 1.0] 1.0 1.0 \n",
"ridge 1.0 [1.0, 1.0] 1.0 1.0 \n",
"decision_tree 1.0 [1.0, 1.0] 1.0 1.0 \n",
"knn 1.0 [1.0, 1.0] 1.0 1.0 \n",
"naive_bayes 1.0 [1.0, 1.0] 1.0 1.0 \n",
"gradient_boosting 1.0 [1.0, 1.0] 1.0 1.0 \n",
"random_forest 1.0 [1.0, 1.0] 1.0 1.0 \n",
"mlp 1.0 [1.0, 1.0] 1.0 1.0 \n",
"\n",
" MCC_test \n",
"logistic 1.0 \n",
"ridge 1.0 \n",
"decision_tree 1.0 \n",
"knn 1.0 \n",
"naive_bayes 1.0 \n",
"gradient_boosting 1.0 \n",
"random_forest 1.0 \n",
"mlp 1.0 "
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Лучшая модель"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Находим ошибки"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>State</th>\n",
" <th>Predicted</th>\n",
" <th>Sex</th>\n",
" <th>GeneralHealth</th>\n",
" <th>PhysicalHealthDays</th>\n",
" <th>MentalHealthDays</th>\n",
" <th>LastCheckupTime</th>\n",
" <th>PhysicalActivities</th>\n",
" <th>SleepHours</th>\n",
" <th>RemovedTeeth</th>\n",
" <th>...</th>\n",
" <th>HeightInMeters</th>\n",
" <th>WeightInKilograms</th>\n",
" <th>BMI</th>\n",
" <th>AlcoholDrinkers</th>\n",
" <th>HIVTesting</th>\n",
" <th>FluVaxLast12</th>\n",
" <th>PneumoVaxEver</th>\n",
" <th>TetanusLast10Tdap</th>\n",
" <th>HighRiskLastYear</th>\n",
" <th>CovidPos</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"<p>0 rows × 41 columns</p>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [State, Predicted, Sex, GeneralHealth, PhysicalHealthDays, MentalHealthDays, LastCheckupTime, PhysicalActivities, SleepHours, RemovedTeeth, HadHeartAttack, HadAngina, HadStroke, HadAsthma, HadSkinCancer, HadCOPD, HadDepressiveDisorder, HadKidneyDisease, HadArthritis, HadDiabetes, DeafOrHardOfHearing, BlindOrVisionDifficulty, DifficultyConcentrating, DifficultyWalking, DifficultyDressingBathing, DifficultyErrands, SmokerStatus, ECigaretteUsage, ChestScan, RaceEthnicityCategory, AgeCategory, HeightInMeters, WeightInKilograms, BMI, AlcoholDrinkers, HIVTesting, FluVaxLast12, PneumoVaxEver, TetanusLast10Tdap, HighRiskLastYear, CovidPos]\n",
"Index: []\n",
"\n",
"[0 rows x 41 columns]"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_new_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[TARGET_COLUMN_NAME_CLASSIFICATION] != y_new_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_new_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>State</th>\n",
" <th>Sex</th>\n",
" <th>GeneralHealth</th>\n",
" <th>PhysicalHealthDays</th>\n",
" <th>MentalHealthDays</th>\n",
" <th>LastCheckupTime</th>\n",
" <th>PhysicalActivities</th>\n",
" <th>SleepHours</th>\n",
" <th>RemovedTeeth</th>\n",
" <th>HadHeartAttack</th>\n",
" <th>...</th>\n",
" <th>HeightInMeters</th>\n",
" <th>WeightInKilograms</th>\n",
" <th>BMI</th>\n",
" <th>AlcoholDrinkers</th>\n",
" <th>HIVTesting</th>\n",
" <th>FluVaxLast12</th>\n",
" <th>PneumoVaxEver</th>\n",
" <th>TetanusLast10Tdap</th>\n",
" <th>HighRiskLastYear</th>\n",
" <th>CovidPos</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>187130</th>\n",
" <td>South Dakota</td>\n",
" <td>Male</td>\n",
" <td>Poor</td>\n",
" <td>30.0</td>\n",
" <td>30.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>No</td>\n",
" <td>4.0</td>\n",
" <td>None of them</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1.83</td>\n",
" <td>97.98</td>\n",
" <td>29.29</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1 rows × 40 columns</p>\n",
"</div>"
],
"text/plain": [
" State Sex GeneralHealth PhysicalHealthDays MentalHealthDays \\\n",
"187130 South Dakota Male Poor 30.0 30.0 \n",
"\n",
" LastCheckupTime PhysicalActivities \\\n",
"187130 Within past year (anytime less than 12 months ... No \n",
"\n",
" SleepHours RemovedTeeth HadHeartAttack ... HeightInMeters \\\n",
"187130 4.0 None of them 0 ... 1.83 \n",
"\n",
" WeightInKilograms BMI AlcoholDrinkers HIVTesting FluVaxLast12 \\\n",
"187130 97.98 29.29 Yes No No \n",
"\n",
" PneumoVaxEver TetanusLast10Tdap \\\n",
"187130 No No, did not receive any tetanus shot in the pa... \n",
"\n",
" HighRiskLastYear CovidPos \n",
"187130 No No \n",
"\n",
"[1 rows x 40 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PhysicalHealthDays</th>\n",
" <th>MentalHealthDays</th>\n",
" <th>SleepHours</th>\n",
" <th>HadHeartAttack</th>\n",
" <th>HeightInMeters</th>\n",
" <th>WeightInKilograms</th>\n",
" <th>BMI</th>\n",
" <th>State_Alaska</th>\n",
" <th>State_Arizona</th>\n",
" <th>State_Arkansas</th>\n",
" <th>...</th>\n",
" <th>AlcoholDrinkers_Yes</th>\n",
" <th>HIVTesting_Yes</th>\n",
" <th>FluVaxLast12_Yes</th>\n",
" <th>PneumoVaxEver_Yes</th>\n",
" <th>TetanusLast10Tdap_Yes, received Tdap</th>\n",
" <th>TetanusLast10Tdap_Yes, received tetanus shot but not sure what type</th>\n",
" <th>TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap</th>\n",
" <th>HighRiskLastYear_Yes</th>\n",
" <th>CovidPos_Tested positive using home test without a health professional</th>\n",
" <th>CovidPos_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>187130</th>\n",
" <td>3.077503</td>\n",
" <td>3.188668</td>\n",
" <td>-2.094884</td>\n",
" <td>-0.24034</td>\n",
" <td>1.170796</td>\n",
" <td>0.67449</td>\n",
" <td>0.096168</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1 rows × 109 columns</p>\n",
"</div>"
],
"text/plain": [
" PhysicalHealthDays MentalHealthDays SleepHours HadHeartAttack \\\n",
"187130 3.077503 3.188668 -2.094884 -0.24034 \n",
"\n",
" HeightInMeters WeightInKilograms BMI State_Alaska \\\n",
"187130 1.170796 0.67449 0.096168 0.0 \n",
"\n",
" State_Arizona State_Arkansas ... AlcoholDrinkers_Yes \\\n",
"187130 0.0 0.0 ... 1.0 \n",
"\n",
" HIVTesting_Yes FluVaxLast12_Yes PneumoVaxEver_Yes \\\n",
"187130 0.0 0.0 0.0 \n",
"\n",
" TetanusLast10Tdap_Yes, received Tdap \\\n",
"187130 0.0 \n",
"\n",
" TetanusLast10Tdap_Yes, received tetanus shot but not sure what type \\\n",
"187130 0.0 \n",
"\n",
" TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap \\\n",
"187130 0.0 \n",
"\n",
" HighRiskLastYear_Yes \\\n",
"187130 0.0 \n",
"\n",
" CovidPos_Tested positive using home test without a health professional \\\n",
"187130 0.0 \n",
"\n",
" CovidPos_Yes \n",
"187130 0.0 \n",
"\n",
"[1 rows x 109 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 0 (proba: [9.99540301e-01 4.59698535e-04])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 0'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"\n",
"example_id = 187130\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Создаем гиперпараметры методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 10,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 100}"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"\n",
"optimized_model_type = 'random_forest'\n",
"random_state = 9\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=42,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=50,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели и сама оценка данных"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Precision_train</th>\n",
" <th>Precision_test</th>\n",
" <th>Recall_train</th>\n",
" <th>Recall_test</th>\n",
" <th>Accuracy_train</th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_train</th>\n",
" <th>F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Name</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Old</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>New</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.304987</td>\n",
" <td>0.298846</td>\n",
" <td>0.962046</td>\n",
" <td>0.961711</td>\n",
" <td>0.467418</td>\n",
" <td>0.460172</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Precision_train Precision_test Recall_train Recall_test Accuracy_train \\\n",
"Name \n",
"Old 1.0 1.0 1.0 1.0 1.0 \n",
"New 1.0 1.0 0.304987 0.298846 0.962046 \n",
"\n",
" Accuracy_test F1_train F1_test \n",
"Name \n",
"Old 1.0 [1.0, 1.0] [1.0, 1.0] \n",
"New 0.961711 0.467418 0.460172 "
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")\n",
"\n",
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_test</th>\n",
" <th>ROC_AUC_test</th>\n",
" <th>Cohen_kappa_test</th>\n",
" <th>MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Name</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Old</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>New</th>\n",
" <td>0.961711</td>\n",
" <td>0.460172</td>\n",
" <td>0.999994</td>\n",
" <td>0.446257</td>\n",
" <td>0.535924</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Accuracy_test F1_test ROC_AUC_test Cohen_kappa_test MCC_test\n",
"Name \n",
"Old 1.0 [1.0, 1.0] 1.0 1.0 1.0\n",
"New 0.961711 0.460172 0.999994 0.446257 0.535924"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA9UAAAGsCAYAAADT+IQ/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABsQklEQVR4nO3dd3gU1eLG8XeTkEYahBICgYDUSAlFMSpNkSBYKF6KlFBEQUB6+0lHBOSiFEW8FIOIBlAEBS+IIChFFBClgxRBIPQkJEASdvf3R25W1wTYJBPSvp/nmUeZPXvm7BLyzplz5ozJarVaBQAAAAAAMswppxsAAAAAAEBeRacaAAAAAIBMolMNAAAAAEAm0akGAAAAACCT6FQDAAAAAJBJdKoBAAAAAMgkOtUAAAAAAGQSnWoAAAAAADLJJacbAABAXnDr1i0lJSUZVp+rq6vc3d0Nqw8AgIwg14xDpxoAgHu4deuWypfzUvRFs2F1BgQE6OTJkwX2BAQAkHPINWPRqQYA4B6SkpIUfdGsk7vLycc763dOxV23qHzdP5SUlFQgTz4AADmLXDMWnWoAABzk4+1kyMkHAAC5AblmDDrVAAA4yGy1yGw1ph4AAHIauWYMOtUAADjIIqssyvrZhxF1AACQVeSaMRjrBwAAAAAgkxipBgDAQRZZZMQEN2NqAQAga8g1Y9CpBgDAQWarVWZr1qe4GVEHAABZRa4Zg+nfAAAAAABkEiPVAAA4iAVdAAD5CblmDDrVAAA4yCKrzJx8AADyCXLNGEz/BgAAAAAgkxipBgDAQUyTAwDkJ+SaMRipBgAAAAAgkxipBgDAQTx6BACQn5BrxqBTDQCAgyz/24yoBwCAnEauGYPp3wAAAAAAZBIj1QAAOMhs0KNHjKgDAICsIteMQacaAAAHma0pmxH1AACQ08g1YzD9GwAAAACATGKkGgAAB7GgCwAgPyHXjEGnGgAAB1lkklkmQ+oBACCnkWvGYPo3AAAAAACZxEg1AAAOslhTNiPqAQAgp5FrxmCkGgAAAACATGKkGgAAB5kNuvfMiDoAAMgqcs0YdKoBAHAQJx8AgPyEXDMG078BAAAAAMgkRqoBAHCQxWqSxWrAo0cMqAMAgKwi14xBpxoAAAcxTQ4AkJ+Qa8Zg+jcAAAAAAJnESDUAAA4yy0lmA65Hmw1oCwAAWUWuGYNONQAADrIadO+ZtYDfewYAyB3INWMw/RsAAAAAgExipBoAAAexoAsAID8h14xBpxoAAAeZrU4yWw2498xqQGMAAMgics0YTP8GAAAAACCTGKkGAMBBFplkMeB6tEUF/JI+ACBXINeMwUg1AAAAAACZxEg1AAAOYkEXAEB+Qq4Zg041AAAOMm5Bl4I9TQ4AkDuQa8Zg+jcAAAAAAJnESDUAAA5KWdAl61PcjKgDAICsIteMQacaAAAHWeQkM6ukAgDyCXLNGEz/BgAAAAAgkxipBgDAQSzoAgDIT8g1Y9CpBgDAQRY5ycI0OQBAPkGuGYPp3wAAAAAAZBIj1QAAOMhsNclszfoKp0bUAQBAVpFrxmCkGgAAAACATGKkGgAAB5kNevSIuYDfewYAyB3INWPQqQYAwEEWq5MsBqySaingq6QCAHIHcs0YTP8GACCPmTp1qkwmkwYOHGjbd+vWLfXt21f+/v7y8vJS27ZtdeHCBbv3nT59Wi1btpSnp6dKlCihYcOG6fbt23ZlNm/erDp16sjNzU0VK1ZUZGRkmuO/9957Cg4Olru7u+rXr6+ffvopOz4mAKCAyOu5RqcaAAAHpU6TM2LLrJ9//lkffPCBatasabd/0KBB+uqrr7RixQpt2bJF586dU5s2bf5qu9msli1bKikpSdu3b9fixYsVGRmpsWPH2sqcPHlSLVu2VJMmTbR3714NHDhQL730ktavX28rs2zZMg0ePFjjxo3Tnj17VKtWLYWHh+vixYuZ/kwAgJxBrhmTayartYCP1QMAcA9xcXHy9fXVB3vqysMr63dO3Yy/rVfq7FZsbKx8fHwcfl98fLzq1KmjuXPn6o033lBoaKhmzpyp2NhYFS9eXJ988oleeOEFSdLhw4dVrVo17dixQ4888oj++9//6plnntG5c+dUsmRJSdK8efM0YsQIXbp0Sa6urhoxYoTWrl2r/fv3247ZoUMHxcTEaN26dZKk+vXr66GHHtK7774rSbJYLAoKClL//v01cuTILH83AIDsR64Zm2uMVAMAkEPi4uLstsTExLuW79u3r1q2bKmmTZva7d+9e7eSk5Pt9letWlVly5bVjh07JEk7duxQjRo1bCcekhQeHq64uDgdOHDAVuafdYeHh9vqSEpK0u7du+3KODk5qWnTprYyAICCq6DmGp1qAAAcZJGTYZskBQUFydfX17ZNmTLljseOiorSnj170i0THR0tV1dX+fn52e0vWbKkoqOjbWX+fuKR+nrqa3crExcXp5s3b+ry5csym83plkmtAwCQd5BrxuQaq38DAOAgs9VJZgNWSU2t48yZM3bT5Nzc3NItf+bMGQ0YMEAbNmyQu7t7lo8PAIBErhmFkWoAAHKIj4+P3Xank4/du3fr4sWLqlOnjlxcXOTi4qItW7Zo9uzZcnFxUcmSJZWUlKSYmBi79124cEEBAQGSpICAgDSrpqb++V5lfHx85OHhoWLFisnZ2TndMql1AAAKroKaa3SqAQBwkEUmw7aMePLJJ7Vv3z7t3bvXttWrV0+dOnWy/X+hQoW0ceNG23uOHDmi06dPKywsTJIUFhamffv22a1mumHDBvn4+CgkJMRW5u91pJZJrcPV1VV169a1K2OxWLRx40ZbGQBA3kGuGZNrTP8GAMBBRk+Tc5S3t7eqV69ut69w4cLy9/e37e/Zs6cGDx6sokWLysfHR/3791dYWJgeeeQRSVKzZs0UEhKiLl266K233lJ0dLRGjx6tvn372kYSevfurXfffVfDhw9Xjx49tGnTJi1fvlxr1661HXfw4MGKiIhQvXr19PDDD2vmzJlKSEhQ9+7ds/KVAAByALlmTK7RqQYAIB9455135OTkpLZt2yoxMVHh4eGaO3eu7XVnZ2etWbNGffr0UVhYmAoXLqyIiAhNnDjRVqZ8+fJau3atBg0apFmzZqlMmTJasGCBwsPDbWXat2+vS5cuaezYsYqOjlZoaKjWrVuXZpEXAACyIi/lGs+pBgDgHlKf5/nvXY8b9jzPofW2Zvh5ngAAGIFcMxb3VAMAAAAAkElM/0aeY7FYdO7cOXl7e8tkytiiCAAKHqvVquvXryswMFBOTlm7lmyxmmSxZv33jhF1IP8g1wBkBLmW+9CpRp5z7tw5BQUF5XQzAOQxZ86cUZkyZbJUh0VOMhswycvCRDH8DbkGIDPItdyDTjXyHG9vb0nSH3uC5eNVsP8BI63WlWvkdBOQy9xWsrbqa9vvDiC3IddwN+Qa/olcy33oVCPPSZ0a5+PlJB9vTj5gz8VUKKebgNzmf8txGjGt1mJ1ksWAR48YUQfyD3INd0OuIQ1yLdehUw0AgIPMMsmsrJ/EGFEHAABZRa4Zo2BfUgAAAAAAIAsYqQYAwEFMkwMA5CfkmjHoVAMA4CCzjJniZs56UwAAyDJyzRgF+5ICAAAAAABZwEg1AAAOYpocACA/IdeMUbA/PQAAAAAAWcBINQAADjJbnWQ24Gq8EXUAAJBV5Jox6FQDAOAgq0yyGLCgi7WAP88TAJA7kGvGKNiXFAAAAAAAyAJGqgEAcBDT5AAA+Qm5Zgw61QAAOMhiNclizfoUNyPqAAAgq8g1YxTsSwoAAAAAAGQBI9UAADjILCeZDbgebUQdAABkFblmDDrVAAA4iGlyAID8hFwzRsG+pAAAAAAAQBYwUg0AgIMscpLFgOvRRtQBAEBWkWvGoFMNAICDzFaTzAZMcTOiDgAAsopcM0bBvqQAAAAAAEAWMFINAICDWNAFAJCfkGvGYKQaAAAAAIBMYqQaAAAHWa1Oslizfj3aakAdAABkFblmDDrVAAA4yCyTzDJgQRcD6gAAIKvINWMU7EsKAAAAAABkASPVAAA4yGI1ZjEWi9W
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[f\"No {TARGET_COLUMN_NAME_CLASSIFICATION}\", TARGET_COLUMN_NAME_CLASSIFICATION]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Модель хорошо классифицировала объекты, которые относятся к \"No HadHeartAttack\" и \"HadHeartAttack\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Регрессия"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Бизнес цель 2: \n",
"Предсказание среднего количества часов сна в день (SleepTime) на основе других факторов."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем выборки"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>State</th>\n",
" <th>Sex</th>\n",
" <th>GeneralHealth</th>\n",
" <th>PhysicalHealthDays</th>\n",
" <th>MentalHealthDays</th>\n",
" <th>LastCheckupTime</th>\n",
" <th>PhysicalActivities</th>\n",
" <th>RemovedTeeth</th>\n",
" <th>HadHeartAttack</th>\n",
" <th>HadAngina</th>\n",
" <th>...</th>\n",
" <th>HeightInMeters</th>\n",
" <th>WeightInKilograms</th>\n",
" <th>BMI</th>\n",
" <th>AlcoholDrinkers</th>\n",
" <th>HIVTesting</th>\n",
" <th>FluVaxLast12</th>\n",
" <th>PneumoVaxEver</th>\n",
" <th>TetanusLast10Tdap</th>\n",
" <th>HighRiskLastYear</th>\n",
" <th>CovidPos</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>108769</th>\n",
" <td>Minnesota</td>\n",
" <td>Male</td>\n",
" <td>Good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>None of them</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.73</td>\n",
" <td>83.91</td>\n",
" <td>28.13</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>240750</th>\n",
" <td>Guam</td>\n",
" <td>Male</td>\n",
" <td>Excellent</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past 2 years (1 year but less than 2 ye...</td>\n",
" <td>Yes</td>\n",
" <td>None of them</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>...</td>\n",
" <td>1.65</td>\n",
" <td>70.31</td>\n",
" <td>25.79</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>100329</th>\n",
" <td>Michigan</td>\n",
" <td>Female</td>\n",
" <td>Excellent</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>No</td>\n",
" <td>1 to 5</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>...</td>\n",
" <td>1.60</td>\n",
" <td>58.97</td>\n",
" <td>23.03</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>132628</th>\n",
" <td>New Hampshire</td>\n",
" <td>Male</td>\n",
" <td>Good</td>\n",
" <td>4.0</td>\n",
" <td>6.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>1 to 5</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>...</td>\n",
" <td>1.70</td>\n",
" <td>68.04</td>\n",
" <td>23.49</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>72101</th>\n",
" <td>Kansas</td>\n",
" <td>Male</td>\n",
" <td>Very good</td>\n",
" <td>0.0</td>\n",
" <td>2.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>None of them</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.83</td>\n",
" <td>99.79</td>\n",
" <td>29.84</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>119879</th>\n",
" <td>Missouri</td>\n",
" <td>Female</td>\n",
" <td>Excellent</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>None of them</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.78</td>\n",
" <td>61.23</td>\n",
" <td>19.37</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103694</th>\n",
" <td>Michigan</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>10.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>1 to 5</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.63</td>\n",
" <td>74.84</td>\n",
" <td>28.32</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131932</th>\n",
" <td>Nevada</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>1 to 5</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.70</td>\n",
" <td>90.72</td>\n",
" <td>31.32</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>146867</th>\n",
" <td>New York</td>\n",
" <td>Female</td>\n",
" <td>Very good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>1 to 5</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.68</td>\n",
" <td>77.11</td>\n",
" <td>27.44</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>121958</th>\n",
" <td>Montana</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>All</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.65</td>\n",
" <td>98.88</td>\n",
" <td>36.28</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>196817 rows × 39 columns</p>\n",
"</div>"
],
"text/plain": [
" State Sex GeneralHealth PhysicalHealthDays \\\n",
"108769 Minnesota Male Good 0.0 \n",
"240750 Guam Male Excellent 0.0 \n",
"100329 Michigan Female Excellent 3.0 \n",
"132628 New Hampshire Male Good 4.0 \n",
"72101 Kansas Male Very good 0.0 \n",
"... ... ... ... ... \n",
"119879 Missouri Female Excellent 0.0 \n",
"103694 Michigan Female Good 10.0 \n",
"131932 Nevada Female Good 0.0 \n",
"146867 New York Female Very good 0.0 \n",
"121958 Montana Female Good 1.0 \n",
"\n",
" MentalHealthDays LastCheckupTime \\\n",
"108769 0.0 Within past year (anytime less than 12 months ... \n",
"240750 0.0 Within past 2 years (1 year but less than 2 ye... \n",
"100329 0.0 Within past year (anytime less than 12 months ... \n",
"132628 6.0 Within past year (anytime less than 12 months ... \n",
"72101 2.0 Within past year (anytime less than 12 months ... \n",
"... ... ... \n",
"119879 0.0 Within past year (anytime less than 12 months ... \n",
"103694 0.0 Within past year (anytime less than 12 months ... \n",
"131932 0.0 Within past year (anytime less than 12 months ... \n",
"146867 0.0 Within past year (anytime less than 12 months ... \n",
"121958 0.0 Within past year (anytime less than 12 months ... \n",
"\n",
" PhysicalActivities RemovedTeeth HadHeartAttack HadAngina ... \\\n",
"108769 Yes None of them No No ... \n",
"240750 Yes None of them No Yes ... \n",
"100329 No 1 to 5 No Yes ... \n",
"132628 Yes 1 to 5 No Yes ... \n",
"72101 Yes None of them No No ... \n",
"... ... ... ... ... ... \n",
"119879 Yes None of them No No ... \n",
"103694 Yes 1 to 5 No No ... \n",
"131932 Yes 1 to 5 No No ... \n",
"146867 Yes 1 to 5 No No ... \n",
"121958 Yes All No No ... \n",
"\n",
" HeightInMeters WeightInKilograms BMI AlcoholDrinkers HIVTesting \\\n",
"108769 1.73 83.91 28.13 No Yes \n",
"240750 1.65 70.31 25.79 Yes No \n",
"100329 1.60 58.97 23.03 No No \n",
"132628 1.70 68.04 23.49 Yes No \n",
"72101 1.83 99.79 29.84 Yes No \n",
"... ... ... ... ... ... \n",
"119879 1.78 61.23 19.37 Yes Yes \n",
"103694 1.63 74.84 28.32 Yes No \n",
"131932 1.70 90.72 31.32 No No \n",
"146867 1.68 77.11 27.44 Yes No \n",
"121958 1.65 98.88 36.28 Yes Yes \n",
"\n",
" FluVaxLast12 PneumoVaxEver \\\n",
"108769 Yes Yes \n",
"240750 No No \n",
"100329 Yes Yes \n",
"132628 Yes No \n",
"72101 Yes Yes \n",
"... ... ... \n",
"119879 Yes No \n",
"103694 Yes Yes \n",
"131932 No No \n",
"146867 Yes No \n",
"121958 Yes Yes \n",
"\n",
" TetanusLast10Tdap HighRiskLastYear \\\n",
"108769 Yes, received tetanus shot but not sure what type Yes \n",
"240750 No, did not receive any tetanus shot in the pa... No \n",
"100329 Yes, received tetanus shot but not sure what type No \n",
"132628 No, did not receive any tetanus shot in the pa... No \n",
"72101 Yes, received Tdap No \n",
"... ... ... \n",
"119879 Yes, received tetanus shot but not sure what type No \n",
"103694 No, did not receive any tetanus shot in the pa... No \n",
"131932 No, did not receive any tetanus shot in the pa... No \n",
"146867 No, did not receive any tetanus shot in the pa... No \n",
"121958 Yes, received tetanus shot but not sure what type No \n",
"\n",
" CovidPos \n",
"108769 Yes \n",
"240750 Yes \n",
"100329 No \n",
"132628 No \n",
"72101 No \n",
"... ... \n",
"119879 No \n",
"103694 No \n",
"131932 No \n",
"146867 Yes \n",
"121958 No \n",
"\n",
"[196817 rows x 39 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>SleepHours</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>108769</th>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>240750</th>\n",
" <td>7.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>100329</th>\n",
" <td>9.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>132628</th>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>72101</th>\n",
" <td>7.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>119879</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103694</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131932</th>\n",
" <td>7.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>146867</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>121958</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>196817 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" SleepHours\n",
"108769 6.0\n",
"240750 7.0\n",
"100329 9.0\n",
"132628 6.0\n",
"72101 7.0\n",
"... ...\n",
"119879 8.0\n",
"103694 8.0\n",
"131932 7.0\n",
"146867 8.0\n",
"121958 8.0\n",
"\n",
"[196817 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>State</th>\n",
" <th>Sex</th>\n",
" <th>GeneralHealth</th>\n",
" <th>PhysicalHealthDays</th>\n",
" <th>MentalHealthDays</th>\n",
" <th>LastCheckupTime</th>\n",
" <th>PhysicalActivities</th>\n",
" <th>RemovedTeeth</th>\n",
" <th>HadHeartAttack</th>\n",
" <th>HadAngina</th>\n",
" <th>...</th>\n",
" <th>HeightInMeters</th>\n",
" <th>WeightInKilograms</th>\n",
" <th>BMI</th>\n",
" <th>AlcoholDrinkers</th>\n",
" <th>HIVTesting</th>\n",
" <th>FluVaxLast12</th>\n",
" <th>PneumoVaxEver</th>\n",
" <th>TetanusLast10Tdap</th>\n",
" <th>HighRiskLastYear</th>\n",
" <th>CovidPos</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>194767</th>\n",
" <td>Texas</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>None of them</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.68</td>\n",
" <td>113.40</td>\n",
" <td>40.35</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231923</th>\n",
" <td>Wisconsin</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>1 to 5</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.73</td>\n",
" <td>104.33</td>\n",
" <td>34.97</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>52815</th>\n",
" <td>Idaho</td>\n",
" <td>Male</td>\n",
" <td>Poor</td>\n",
" <td>7.0</td>\n",
" <td>10.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>1 to 5</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>...</td>\n",
" <td>1.73</td>\n",
" <td>104.33</td>\n",
" <td>34.97</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65909</th>\n",
" <td>Iowa</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>20.0</td>\n",
" <td>10.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>No</td>\n",
" <td>All</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.68</td>\n",
" <td>127.01</td>\n",
" <td>45.19</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>No, did not receive any tetanus shot in the pa...</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>184154</th>\n",
" <td>South Dakota</td>\n",
" <td>Female</td>\n",
" <td>Excellent</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>None of them</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.60</td>\n",
" <td>49.90</td>\n",
" <td>19.49</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>Tested positive using home test without a heal...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>57503</th>\n",
" <td>Indiana</td>\n",
" <td>Female</td>\n",
" <td>Fair</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>1 to 5</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.63</td>\n",
" <td>97.52</td>\n",
" <td>36.90</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>47420</th>\n",
" <td>Hawaii</td>\n",
" <td>Female</td>\n",
" <td>Fair</td>\n",
" <td>30.0</td>\n",
" <td>5.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>None of them</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.70</td>\n",
" <td>77.56</td>\n",
" <td>26.78</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>186088</th>\n",
" <td>South Dakota</td>\n",
" <td>Female</td>\n",
" <td>Good</td>\n",
" <td>15.0</td>\n",
" <td>15.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>1 to 5</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.73</td>\n",
" <td>54.88</td>\n",
" <td>18.40</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11687</th>\n",
" <td>Arkansas</td>\n",
" <td>Male</td>\n",
" <td>Excellent</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>None of them</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.78</td>\n",
" <td>88.45</td>\n",
" <td>27.98</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes, received tetanus shot but not sure what type</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" </tr>\n",
" <tr>\n",
" <th>200835</th>\n",
" <td>Utah</td>\n",
" <td>Male</td>\n",
" <td>Very good</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>Within past year (anytime less than 12 months ...</td>\n",
" <td>Yes</td>\n",
" <td>None of them</td>\n",
" <td>No</td>\n",
" <td>No</td>\n",
" <td>...</td>\n",
" <td>1.91</td>\n",
" <td>118.39</td>\n",
" <td>32.62</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" <td>Yes, received Tdap</td>\n",
" <td>No</td>\n",
" <td>Yes</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>49205 rows × 39 columns</p>\n",
"</div>"
],
"text/plain": [
" State Sex GeneralHealth PhysicalHealthDays \\\n",
"194767 Texas Female Good 0.0 \n",
"231923 Wisconsin Female Good 2.0 \n",
"52815 Idaho Male Poor 7.0 \n",
"65909 Iowa Female Good 20.0 \n",
"184154 South Dakota Female Excellent 0.0 \n",
"... ... ... ... ... \n",
"57503 Indiana Female Fair 3.0 \n",
"47420 Hawaii Female Fair 30.0 \n",
"186088 South Dakota Female Good 15.0 \n",
"11687 Arkansas Male Excellent 0.0 \n",
"200835 Utah Male Very good 0.0 \n",
"\n",
" MentalHealthDays LastCheckupTime \\\n",
"194767 0.0 Within past year (anytime less than 12 months ... \n",
"231923 5.0 Within past year (anytime less than 12 months ... \n",
"52815 10.0 Within past year (anytime less than 12 months ... \n",
"65909 10.0 Within past year (anytime less than 12 months ... \n",
"184154 0.0 Within past year (anytime less than 12 months ... \n",
"... ... ... \n",
"57503 0.0 Within past year (anytime less than 12 months ... \n",
"47420 5.0 Within past year (anytime less than 12 months ... \n",
"186088 15.0 Within past year (anytime less than 12 months ... \n",
"11687 0.0 Within past year (anytime less than 12 months ... \n",
"200835 0.0 Within past year (anytime less than 12 months ... \n",
"\n",
" PhysicalActivities RemovedTeeth HadHeartAttack HadAngina ... \\\n",
"194767 Yes None of them No No ... \n",
"231923 Yes 1 to 5 No No ... \n",
"52815 Yes 1 to 5 No Yes ... \n",
"65909 No All Yes No ... \n",
"184154 Yes None of them No No ... \n",
"... ... ... ... ... ... \n",
"57503 Yes 1 to 5 Yes No ... \n",
"47420 Yes None of them No No ... \n",
"186088 Yes 1 to 5 No No ... \n",
"11687 Yes None of them No No ... \n",
"200835 Yes None of them No No ... \n",
"\n",
" HeightInMeters WeightInKilograms BMI AlcoholDrinkers HIVTesting \\\n",
"194767 1.68 113.40 40.35 No No \n",
"231923 1.73 104.33 34.97 Yes Yes \n",
"52815 1.73 104.33 34.97 No No \n",
"65909 1.68 127.01 45.19 No No \n",
"184154 1.60 49.90 19.49 Yes No \n",
"... ... ... ... ... ... \n",
"57503 1.63 97.52 36.90 Yes Yes \n",
"47420 1.70 77.56 26.78 No Yes \n",
"186088 1.73 54.88 18.40 Yes No \n",
"11687 1.78 88.45 27.98 Yes Yes \n",
"200835 1.91 118.39 32.62 No Yes \n",
"\n",
" FluVaxLast12 PneumoVaxEver \\\n",
"194767 No No \n",
"231923 No Yes \n",
"52815 Yes Yes \n",
"65909 No No \n",
"184154 Yes No \n",
"... ... ... \n",
"57503 Yes Yes \n",
"47420 Yes No \n",
"186088 Yes No \n",
"11687 Yes No \n",
"200835 No Yes \n",
"\n",
" TetanusLast10Tdap HighRiskLastYear \\\n",
"194767 No, did not receive any tetanus shot in the pa... No \n",
"231923 No, did not receive any tetanus shot in the pa... No \n",
"52815 Yes, received tetanus shot but not sure what type No \n",
"65909 No, did not receive any tetanus shot in the pa... No \n",
"184154 Yes, received Tdap No \n",
"... ... ... \n",
"57503 Yes, received Tdap No \n",
"47420 Yes, received tetanus shot but not sure what type No \n",
"186088 Yes, received tetanus shot but not sure what type No \n",
"11687 Yes, received tetanus shot but not sure what type No \n",
"200835 Yes, received Tdap No \n",
"\n",
" CovidPos \n",
"194767 Yes \n",
"231923 No \n",
"52815 No \n",
"65909 Yes \n",
"184154 Tested positive using home test without a heal... \n",
"... ... \n",
"57503 No \n",
"47420 No \n",
"186088 Yes \n",
"11687 No \n",
"200835 Yes \n",
"\n",
"[49205 rows x 39 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>SleepHours</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>194767</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>231923</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>52815</th>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65909</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>184154</th>\n",
" <td>7.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>57503</th>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>47420</th>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>186088</th>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11687</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>200835</th>\n",
" <td>8.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>49205 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" SleepHours\n",
"194767 8.0\n",
"231923 8.0\n",
"52815 6.0\n",
"65909 8.0\n",
"184154 7.0\n",
"... ...\n",
"57503 6.0\n",
"47420 6.0\n",
"186088 6.0\n",
"11687 8.0\n",
"200835 8.0\n",
"\n",
"[49205 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df = pd.read_csv(\"csv\\\\heart_2022_no_nans.csv\")\n",
"\n",
"TARGET_COLUMN_NAME_REGRESSION = \"SleepHours\"\n",
"\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" target_colname: str,\n",
" frac_train: float = 0.8,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" if target_colname not in df_input.columns:\n",
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
" \n",
" X = df_input.drop(columns=[target_colname])\n",
" y = df_input[[target_colname]]\n",
"\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
" df, \n",
" target_colname=TARGET_COLUMN_NAME_REGRESSION, \n",
" frac_train=0.8, \n",
" random_state=42\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
"def get_filtered_columns(df: DataFrame, no_numeric=False, no_text=False) -> list[str]:\n",
" \"\"\"\n",
" Возвращает список колонок по фильтру\n",
" \"\"\"\n",
" w = []\n",
" for column in df.columns:\n",
" if no_numeric and pd.api.types.is_numeric_dtype(df[column]):\n",
" continue\n",
" if no_text and not pd.api.types.is_numeric_dtype(df[column]):\n",
" continue\n",
" w.append(column)\n",
" return w"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выполним one-hot encoding, чтобы избавиться от категориальных признаков"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PhysicalHealthDays</th>\n",
" <th>MentalHealthDays</th>\n",
" <th>HeightInMeters</th>\n",
" <th>WeightInKilograms</th>\n",
" <th>BMI</th>\n",
" <th>State_Alaska</th>\n",
" <th>State_Arizona</th>\n",
" <th>State_Arkansas</th>\n",
" <th>State_California</th>\n",
" <th>State_Colorado</th>\n",
" <th>...</th>\n",
" <th>AlcoholDrinkers_Yes</th>\n",
" <th>HIVTesting_Yes</th>\n",
" <th>FluVaxLast12_Yes</th>\n",
" <th>PneumoVaxEver_Yes</th>\n",
" <th>TetanusLast10Tdap_Yes, received Tdap</th>\n",
" <th>TetanusLast10Tdap_Yes, received tetanus shot but not sure what type</th>\n",
" <th>TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap</th>\n",
" <th>HighRiskLastYear_Yes</th>\n",
" <th>CovidPos_Tested positive using home test without a health professional</th>\n",
" <th>CovidPos_Yes</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>108769</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.73</td>\n",
" <td>83.91</td>\n",
" <td>28.13</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>240750</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.65</td>\n",
" <td>70.31</td>\n",
" <td>25.79</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>100329</th>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>1.60</td>\n",
" <td>58.97</td>\n",
" <td>23.03</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>132628</th>\n",
" <td>4.0</td>\n",
" <td>6.0</td>\n",
" <td>1.70</td>\n",
" <td>68.04</td>\n",
" <td>23.49</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>72101</th>\n",
" <td>0.0</td>\n",
" <td>2.0</td>\n",
" <td>1.83</td>\n",
" <td>99.79</td>\n",
" <td>29.84</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>119879</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.78</td>\n",
" <td>61.23</td>\n",
" <td>19.37</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103694</th>\n",
" <td>10.0</td>\n",
" <td>0.0</td>\n",
" <td>1.63</td>\n",
" <td>74.84</td>\n",
" <td>28.32</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131932</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.70</td>\n",
" <td>90.72</td>\n",
" <td>31.32</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>146867</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.68</td>\n",
" <td>77.11</td>\n",
" <td>27.44</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>121958</th>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.65</td>\n",
" <td>98.88</td>\n",
" <td>36.28</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>...</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>196817 rows × 121 columns</p>\n",
"</div>"
],
"text/plain": [
" PhysicalHealthDays MentalHealthDays HeightInMeters \\\n",
"108769 0.0 0.0 1.73 \n",
"240750 0.0 0.0 1.65 \n",
"100329 3.0 0.0 1.60 \n",
"132628 4.0 6.0 1.70 \n",
"72101 0.0 2.0 1.83 \n",
"... ... ... ... \n",
"119879 0.0 0.0 1.78 \n",
"103694 10.0 0.0 1.63 \n",
"131932 0.0 0.0 1.70 \n",
"146867 0.0 0.0 1.68 \n",
"121958 1.0 0.0 1.65 \n",
"\n",
" WeightInKilograms BMI State_Alaska State_Arizona State_Arkansas \\\n",
"108769 83.91 28.13 False False False \n",
"240750 70.31 25.79 False False False \n",
"100329 58.97 23.03 False False False \n",
"132628 68.04 23.49 False False False \n",
"72101 99.79 29.84 False False False \n",
"... ... ... ... ... ... \n",
"119879 61.23 19.37 False False False \n",
"103694 74.84 28.32 False False False \n",
"131932 90.72 31.32 False False False \n",
"146867 77.11 27.44 False False False \n",
"121958 98.88 36.28 False False False \n",
"\n",
" State_California State_Colorado ... AlcoholDrinkers_Yes \\\n",
"108769 False False ... False \n",
"240750 False False ... True \n",
"100329 False False ... False \n",
"132628 False False ... True \n",
"72101 False False ... True \n",
"... ... ... ... ... \n",
"119879 False False ... True \n",
"103694 False False ... True \n",
"131932 False False ... False \n",
"146867 False False ... True \n",
"121958 False False ... True \n",
"\n",
" HIVTesting_Yes FluVaxLast12_Yes PneumoVaxEver_Yes \\\n",
"108769 True True True \n",
"240750 False False False \n",
"100329 False True True \n",
"132628 False True False \n",
"72101 False True True \n",
"... ... ... ... \n",
"119879 True True False \n",
"103694 False True True \n",
"131932 False False False \n",
"146867 False True False \n",
"121958 True True True \n",
"\n",
" TetanusLast10Tdap_Yes, received Tdap \\\n",
"108769 False \n",
"240750 False \n",
"100329 False \n",
"132628 False \n",
"72101 True \n",
"... ... \n",
"119879 False \n",
"103694 False \n",
"131932 False \n",
"146867 False \n",
"121958 False \n",
"\n",
" TetanusLast10Tdap_Yes, received tetanus shot but not sure what type \\\n",
"108769 True \n",
"240750 False \n",
"100329 True \n",
"132628 False \n",
"72101 False \n",
"... ... \n",
"119879 True \n",
"103694 False \n",
"131932 False \n",
"146867 False \n",
"121958 True \n",
"\n",
" TetanusLast10Tdap_Yes, received tetanus shot, but not Tdap \\\n",
"108769 False \n",
"240750 False \n",
"100329 False \n",
"132628 False \n",
"72101 False \n",
"... ... \n",
"119879 False \n",
"103694 False \n",
"131932 False \n",
"146867 False \n",
"121958 False \n",
"\n",
" HighRiskLastYear_Yes \\\n",
"108769 True \n",
"240750 False \n",
"100329 False \n",
"132628 False \n",
"72101 False \n",
"... ... \n",
"119879 False \n",
"103694 False \n",
"131932 False \n",
"146867 False \n",
"121958 False \n",
"\n",
" CovidPos_Tested positive using home test without a health professional \\\n",
"108769 False \n",
"240750 False \n",
"100329 False \n",
"132628 False \n",
"72101 False \n",
"... ... \n",
"119879 False \n",
"103694 False \n",
"131932 False \n",
"146867 False \n",
"121958 False \n",
"\n",
" CovidPos_Yes \n",
"108769 True \n",
"240750 True \n",
"100329 False \n",
"132628 False \n",
"72101 False \n",
"... ... \n",
"119879 False \n",
"103694 False \n",
"131932 False \n",
"146867 True \n",
"121958 False \n",
"\n",
"[196817 rows x 121 columns]"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cat_features = get_filtered_columns(df, no_numeric=True)\n",
"\n",
"X_test = pd.get_dummies(X_test, columns=cat_features, drop_first=True)\n",
"X_train = pd.get_dummies(X_train, columns=cat_features, drop_first=True)\n",
"\n",
"X_test\n",
"X_train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение перечня алгоритмов решения задачи регрессии"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import math\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
"\n",
" fitted_model = models[model_name][\"model\"].fit(\n",
" X_train.values, y_train.values.ravel()\n",
" )\n",
" y_train_pred = fitted_model.predict(X_train.values)\n",
" y_test_pred = fitted_model.predict(X_test.values)\n",
" models[model_name][\"fitted\"] = fitted_model\n",
" models[model_name][\"train_preds\"] = y_train_pred\n",
" models[model_name][\"preds\"] = y_test_pred\n",
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_train, y_train_pred)\n",
" )\n",
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выводим результаты оценки"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_562d9_row0_col0 {\n",
" background-color: #9dd93b;\n",
" color: #000000;\n",
"}\n",
"#T_562d9_row0_col1, #T_562d9_row1_col1, #T_562d9_row2_col1, #T_562d9_row7_col0 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row0_col2, #T_562d9_row1_col2, #T_562d9_row2_col2 {\n",
" background-color: #5002a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row0_col3, #T_562d9_row1_col3, #T_562d9_row2_col3, #T_562d9_row7_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row1_col0, #T_562d9_row2_col0 {\n",
" background-color: #a0da39;\n",
" color: #000000;\n",
"}\n",
"#T_562d9_row3_col0 {\n",
" background-color: #98d83e;\n",
" color: #000000;\n",
"}\n",
"#T_562d9_row3_col1 {\n",
" background-color: #23888e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row3_col2, #T_562d9_row7_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row3_col3 {\n",
" background-color: #d35171;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row4_col0 {\n",
" background-color: #48c16e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row4_col1, #T_562d9_row5_col1, #T_562d9_row6_col1 {\n",
" background-color: #20928c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row4_col2 {\n",
" background-color: #6c00a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row4_col3, #T_562d9_row5_col3 {\n",
" background-color: #ca457a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row5_col0 {\n",
" background-color: #4ac16d;\n",
" color: #000000;\n",
"}\n",
"#T_562d9_row5_col2 {\n",
" background-color: #6e00a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row6_col0, #T_562d9_row7_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_562d9_row6_col2 {\n",
" background-color: #6001a6;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_562d9_row6_col3 {\n",
" background-color: #c9447a;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_562d9\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_562d9_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_562d9_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_562d9_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_562d9_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_562d9_level0_row0\" class=\"row_heading level0 row0\" >mlp</th>\n",
" <td id=\"T_562d9_row0_col0\" class=\"data row0 col0\" >1.401571</td>\n",
" <td id=\"T_562d9_row0_col1\" class=\"data row0 col1\" >1.401556</td>\n",
" <td id=\"T_562d9_row0_col2\" class=\"data row0 col2\" >1.001832</td>\n",
" <td id=\"T_562d9_row0_col3\" class=\"data row0 col3\" >0.049273</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_562d9_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_562d9_row1_col0\" class=\"data row1 col0\" >1.403185</td>\n",
" <td id=\"T_562d9_row1_col1\" class=\"data row1 col1\" >1.401859</td>\n",
" <td id=\"T_562d9_row1_col2\" class=\"data row1 col2\" >1.001885</td>\n",
" <td id=\"T_562d9_row1_col3\" class=\"data row1 col3\" >0.048861</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_562d9_level0_row2\" class=\"row_heading level0 row2\" >linear</th>\n",
" <td id=\"T_562d9_row2_col0\" class=\"data row2 col0\" >1.403184</td>\n",
" <td id=\"T_562d9_row2_col1\" class=\"data row2 col1\" >1.401860</td>\n",
" <td id=\"T_562d9_row2_col2\" class=\"data row2 col2\" >1.001898</td>\n",
" <td id=\"T_562d9_row2_col3\" class=\"data row2 col3\" >0.048860</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_562d9_level0_row3\" class=\"row_heading level0 row3\" >random_forest</th>\n",
" <td id=\"T_562d9_row3_col0\" class=\"data row3 col0\" >1.400360</td>\n",
" <td id=\"T_562d9_row3_col1\" class=\"data row3 col1\" >1.408185</td>\n",
" <td id=\"T_562d9_row3_col2\" class=\"data row3 col2\" >1.001482</td>\n",
" <td id=\"T_562d9_row3_col3\" class=\"data row3 col3\" >0.040258</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_562d9_level0_row4\" class=\"row_heading level0 row4\" >linear_poly</th>\n",
" <td id=\"T_562d9_row4_col0\" class=\"data row4 col0\" >1.365912</td>\n",
" <td id=\"T_562d9_row4_col1\" class=\"data row4 col1\" >1.416653</td>\n",
" <td id=\"T_562d9_row4_col2\" class=\"data row4 col2\" >1.008370</td>\n",
" <td id=\"T_562d9_row4_col3\" class=\"data row4 col3\" >0.028680</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_562d9_level0_row5\" class=\"row_heading level0 row5\" >linear_interact</th>\n",
" <td id=\"T_562d9_row5_col0\" class=\"data row5 col0\" >1.366066</td>\n",
" <td id=\"T_562d9_row5_col1\" class=\"data row5 col1\" >1.417008</td>\n",
" <td id=\"T_562d9_row5_col2\" class=\"data row5 col2\" >1.008543</td>\n",
" <td id=\"T_562d9_row5_col3\" class=\"data row5 col3\" >0.028193</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_562d9_level0_row6\" class=\"row_heading level0 row6\" >decision_tree</th>\n",
" <td id=\"T_562d9_row6_col0\" class=\"data row6 col0\" >1.406026</td>\n",
" <td id=\"T_562d9_row6_col1\" class=\"data row6 col1\" >1.417750</td>\n",
" <td id=\"T_562d9_row6_col2\" class=\"data row6 col2\" >1.005576</td>\n",
" <td id=\"T_562d9_row6_col3\" class=\"data row6 col3\" >0.027175</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_562d9_level0_row7\" class=\"row_heading level0 row7\" >knn</th>\n",
" <td id=\"T_562d9_row7_col0\" class=\"data row7 col0\" >1.296492</td>\n",
" <td id=\"T_562d9_row7_col1\" class=\"data row7 col1\" >1.493316</td>\n",
" <td id=\"T_562d9_row7_col2\" class=\"data row7 col2\" >1.041495</td>\n",
" <td id=\"T_562d9_row7_col3\" class=\"data row7 col3\" >-0.079292</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x13672dd7fe0>"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выводим лучшую модель"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'mlp'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбираем гиперпараметры методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие параметры: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 100}\n",
"Лучший результат (MSE): 1.9866610870680514\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestRegressor\n",
"\n",
"\n",
"X = df[get_filtered_columns(df, no_numeric=True)]\n",
"y = df[TARGET_COLUMN_NAME_REGRESSION] \n",
"\n",
"model = RandomForestRegressor() \n",
"\n",
"param_grid = {\n",
" 'n_estimators': [50, 100], \n",
" 'max_depth': [10, 20], \n",
" 'min_samples_split': [5, 10] \n",
"}\n",
"\n",
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"grid_search.fit(X_train, y_train)\n",
"\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучаем модель с новыми гиперпараметрами и сравниваем новых данных со старыми"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\code\\AIM-PIbd-31-Potapov-N-S\\lab_4\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Старые параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 100}\n",
"Лучший результат (MSE) на старых параметрах: 1.9867639342405718\n",
"\n",
"Новые параметры: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 100}\n",
"Лучший результат (MSE) на новых параметрах: 1.990467882679972\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 1.975249119855746\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 1.4054355623278307\n"
]
}
],
"source": [
"# Old data\n",
"\n",
"old_param_grid = param_grid\n",
"old_grid_search = grid_search\n",
"old_grid_search.fit(X_train, y_train)\n",
"\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_ \n",
"\n",
"# New data\n",
"\n",
"new_param_grid = {\n",
" 'n_estimators': [100],\n",
" 'max_depth': [10],\n",
" 'min_samples_split': [5]\n",
" }\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"new_grid_search.fit(X_train, y_train)\n",
"\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_\n",
"\n",
"new_best_model = RandomForestRegressor(**new_best_params)\n",
"new_best_model.fit(X_train, y_train)\n",
"\n",
"old_best_model = RandomForestRegressor(**old_best_params)\n",
"old_best_model.fit(X_train, y_train)\n",
"\n",
"y_new_pred = new_best_model.predict(X_test)\n",
"y_old_pred = old_best_model.predict(X_test)\n",
"\n",
"mse = metrics.mean_squared_error(y_test, y_new_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nНовые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Визуализация данных"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABRoAAAK9CAYAAABLm9DzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXhTZfr/8U/SlqYtSRDC3kZACm0FR8AFdFB0VGBEbWF0HB0VV1wAHUdFHQVxQ8VlREe+bgMuuEPRwXEbFRRkE1ltKrtpR7YWSUMhpW3y+6O/Zhqa0rRpm4S+X9fVC3LOc85zP0tO0rtnMfh8Pp8AAAAAAAAAIAzGSAcAAAAAAAAAIPaRaAQAAAAAAAAQNhKNAAAAAAAAAMJGohEAAAAAAABA2Eg0AgAAAAAAAAgbiUYAAAAAAAAAYSPRCAAAAAAAACBsJBoBAAAAAAAAhI1EIwAAAAAAAICwkWgEAAAAgBi1YMECrVmzxv96/vz5+vHHHyMXEACgVSPRCABAI23ZskXjxo1Tr169ZDKZZLFYdPrpp+vZZ5/VwYMHIx0eAKAVWL9+vW699VZt2rRJy5Yt04033ii32x3psAAArZTB5/P5Ih0EAACx5uOPP9bFF1+sxMREXXnllerXr58OHTqkxYsXa+7cuRo7dqxeeumlSIcJADjK7dmzR6eddpo2b94sSRo9erTmzp0b4agAAK0ViUYAABpo27ZtOuGEE5SamqqvvvpKXbt2DVi/efNmffzxx7r11lsjFCEAoDUpKyvThg0blJycrMzMzEiHAwBoxbh0GgCABnriiSe0f/9+vfrqq7WSjJLUu3fvgCSjwWDQ+PHjNWfOHPXt21cmk0mDBg3SN998E7Ddzz//rJtvvll9+/ZVUlKSOnTooIsvvljbt28PKDd79mwZDAb/T3Jysvr3769XXnkloNzYsWPVtm3bWvF98MEHMhgMWrhwYcDy5cuXa8SIEbJarUpOTtaZZ56pJUuWBJR54IEHZDAYVFRUFLD8+++/l8Fg0OzZswPq79GjR0C5goICJSUlyWAw1GrXJ598oqFDhyolJUVms1nnn39+SPcZO7w/Dv954IEHasWfn5+vSy65RBaLRR06dNCtt94qj8dTa99vvvmmBg0apKSkJLVv316XXnqpCgoKgsZRV/2H97PH49EDDzygPn36yGQyqWvXrho9erS2bNkiSdq+fXutvnS73Ro0aJB69uypHTt2+Jc/+eSTOu2009ShQwclJSVp0KBB+uCDDwLqKy4u1siRI5WamqrExER17dpVl19+uX7++eeAcqHsq7qd48ePr7V81KhRAeNd3Y4nn3yyVtl+/fpp2LBh/tcLFy6UwWAIWl+1w+fTlClTZDQa9eWXXwaUu+GGG9SmTRutXbu2zn1Vt6Pm3JCk6dOny2AwBMTWFNsfacyr++lIP2PHjpX0v7le873j9Xp1wgknBH3/hfr+HzZsmPr161er7JNPPlmrvh49emjUqFF19kv1WFbv3+FwKCkpSVdeeWVAucWLFysuLk6TJk2qc19S1Xs2KytLbdu2lcVi0eDBgzV//vyAMg2J/8MPP9T555+vbt26KTExUccdd5weeughVVZWBmwbbHyD9b8U2rGroeNx+BxauXKlfz4EizMxMVGDBg1SZmZmg+YxAABNLT7SAQAAEGv+9a9/qVevXjrttNNC3mbRokV69913NXHiRCUmJuqFF17QiBEjtGLFCv8vyCtXrtR3332nSy+9VKmpqdq+fbtmzpypYcOGKS8vT8nJyQH7fOaZZ2Sz2VRSUqJ//vOfuv7669WjRw+dc845DW7TV199pZEjR2rQoEH+BM6sWbN09tln69tvv9Upp5zS4H0GM3ny5KAJvTfeeENXXXWVhg8frscff1wHDhzQzJkz9dvf/larV6+ulbAM5sEHH1TPnj39r/fv36+bbropaNlLLrlEPXr00LRp07Rs2TLNmDFDv/76q15//XV/mUceeUT333+/LrnkEl133XXas2ePnnvuOZ1xxhlavXq12rVrV2u/5557rj+hsnLlSs2YMSNgfWVlpUaNGqUvv/xSl156qW699Va53W598cUX2rBhg4477rha+ywvL9eYMWPkdDq1ZMmSgOT2s88+qwsvvFCXX365Dh06pHfeeUcXX3yxFixYoPPPP1+SdOjQIZnNZt16663q0KGDtmzZoueee07r1q3T+vXrG7SvaHLffffpX//6l6699lqtX79eZrNZn332mV5++WU99NBD+s1vftOg/e3bt0/Tpk1rdDx1bV/fmJ9zzjl64403/OXnzZun3NzcgGXB5kW1N954I2Aco01mZqYeeugh3XnnnfrDH/6gCy+8UKWlpRo7dqwyMjL04IMPHnH70tJS5eTkqEePHjp48KBmz56tMWPGaOnSpY06Ls2ePVtt27bV7bffrrZt2+qrr77S5MmTVVJSounTpzd4f01x7ApFfQnZauHOYwAAwuYDAAAhc7lcPkm+iy66KORtJPkk+b7//nv/sp9//tlnMpl8OTk5/mUHDhyote3SpUt9knyvv/66f9msWbN8knzbtm3zL9u4caNPku+JJ57wL7vqqqt8KSkptfb5/vvv+yT5vv76a5/P5/N5vV5fenq6b/jw4T6v1xsQT8+ePX3nnnuuf9mUKVN8knx79uwJ2OfKlSt9knyzZs0KqP/YY4/1v96wYYPPaDT6Ro4cGRC/2+32tWvXznf99dcH7HPnzp0+q9Vaa/nhqvtj5cqVAcv37Nnjk+SbMmVKrfgvvPDCgLI333yzT5Jv7dq1Pp/P59u+fbsvLi7O98gjjwSUW79+vS8+Pr7W8kOHDvkk+caPH+9fdng/+3w+3z//+U+fJN/TTz9dqx3Vfb9t2zZ/X3q9Xt/ll1/uS05O9i1fvrzWNofPmUOHDvn69evnO/vss2uVremJJ57wSfIVFRU1eF+SfLfcckutfZ5//vkB413djunTp9cqe/zxx/vOPPNM/+uvv/7aJ8n3/vvv1xnz4fPJ56sajzZt2viuu+4636+//urr3r2776STTvKVl5fXuZ+a7ag5N+666y5fp06dfIMGDQqILdztQxnzmqrnaDCHv/c9Ho/Pbrf731OHv/9Cef/7fD7fmWee6Tv++ONrlZ0+fXqtY82xxx7rO//884PG5/P9byxr7r+ystL329/+1te5c2dfUVGR75ZbbvHFx8fXes+GYvfu3T5JvieffLJR8Qc7zo4bN86XnJzs83g8/mUGg8E3efLkgHKH939Djl0NHY+ac+jf//63T5JvxIgRteZGuPMYAICmxqXTAAA0QElJiSTJbDY3aLshQ4Zo0KBB/td2u10XXXSRPvvsM/8le0lJSf715eXlKi4uVu/evdWuXTv98MMPtfb566+/qqioSFu3btUzzzyjuLg4nXnmmbXKFRUVBfwc/jTSNWvWaNOmTbrssstUXFzsL1daWqrf/e53+uabb+T1egO22bt3b8A+XS5XvX1wzz33aODAgbr44osDln/xxRfat2+f/vSnPwXsMy4uTqeeeqq+/vrrevfdULfcckvA6wkTJkiS/v3vf0uqOqvM6/XqkksuCYipS5cuSk9PrxVT9VmaJpPpiPXOnTtXNpvNX19Nh18SKUl33nmn5syZo/feey/o2Vs158yvv/4ql8uloUOHBp0vbrdbu3fv1tKlS/X222/r+OOPV/v27Ru1L4/HU2telZeXB23zgQMHapU9/DLVmjEWFRVp3759Qdcfrl+/fpo6dapeeeUVDR8+XEVFRXrttdcUH9+wi3b++9//6rnnntP9998f9PLWcLZv6Jg3xD/+8Q8VFxdrypQpdZap7/1frbKyslbZAwcOBC1bXl6uoqIiFRcXq6Kiot44jUajZs+erf3792vkyJF64YUXdM899+ikk04KqZ3V9W3ZskWPPfaYjEajTj/99EbFX3OeV8+3oUOH6sCBA8rPz/ev69SpkwoLC48YV2OOXaGORzWfz6d77rlHY8aM0amnnnrEsuHOYwAAmgKXTgMA0AAWi0WS6v3l8HDp6em1lvXp00cHDhzQnj171KVLFx0
"text/plain": [
"<Figure size 1600x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(16, 8))\n",
"plt.scatter(range(len(y_test)), y_test, label=\"Истинные значения\", color=\"black\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_new_pred, label=\"Предсказанные (новые параметры)\", color=\"blue\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_old_pred, label=\"Предсказанные (старые параметры)\", color=\"red\", alpha=0.5)\n",
"plt.xlabel(\"Выборка\")\n",
"plt.ylabel(\"Значения\")\n",
"plt.legend()\n",
"plt.title(\"Сравнение предсказанных и истинных значений\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}