3935 lines
320 KiB
Plaintext
Raw Normal View History

2024-11-09 12:20:09 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Начало 4-й лабораторной\n",
"#### Ближайшие объекты к Земле"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['id', 'name', 'est_diameter_min', 'est_diameter_max',\n",
" 'relative_velocity', 'miss_distance', 'orbiting_body', 'sentry_object',\n",
" 'absolute_magnitude', 'hazardous'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2162635</td>\n",
" <td>162635 (2000 SS164)</td>\n",
" <td>1.198271</td>\n",
" <td>2.679415</td>\n",
" <td>13569.249224</td>\n",
" <td>5.483974e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>16.73</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2277475</td>\n",
" <td>277475 (2005 WK4)</td>\n",
" <td>0.265800</td>\n",
" <td>0.594347</td>\n",
" <td>73588.726663</td>\n",
" <td>6.143813e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.00</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2512244</td>\n",
" <td>512244 (2015 YE18)</td>\n",
" <td>0.722030</td>\n",
" <td>1.614507</td>\n",
" <td>114258.692129</td>\n",
" <td>4.979872e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>17.83</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3596030</td>\n",
" <td>(2012 BV13)</td>\n",
" <td>0.096506</td>\n",
" <td>0.215794</td>\n",
" <td>24764.303138</td>\n",
" <td>2.543497e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3667127</td>\n",
" <td>(2014 GE35)</td>\n",
" <td>0.255009</td>\n",
" <td>0.570217</td>\n",
" <td>42737.733765</td>\n",
" <td>4.627557e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.09</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90831</th>\n",
" <td>3763337</td>\n",
" <td>(2016 VX1)</td>\n",
" <td>0.026580</td>\n",
" <td>0.059435</td>\n",
" <td>52078.886692</td>\n",
" <td>1.230039e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.00</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90832</th>\n",
" <td>3837603</td>\n",
" <td>(2019 AD3)</td>\n",
" <td>0.016771</td>\n",
" <td>0.037501</td>\n",
" <td>46114.605073</td>\n",
" <td>5.432121e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>26.00</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90833</th>\n",
" <td>54017201</td>\n",
" <td>(2020 JP3)</td>\n",
" <td>0.031956</td>\n",
" <td>0.071456</td>\n",
" <td>7566.807732</td>\n",
" <td>2.840077e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.60</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90834</th>\n",
" <td>54115824</td>\n",
" <td>(2021 CN5)</td>\n",
" <td>0.007321</td>\n",
" <td>0.016370</td>\n",
" <td>69199.154484</td>\n",
" <td>6.869206e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.80</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90835</th>\n",
" <td>54205447</td>\n",
" <td>(2021 TW7)</td>\n",
" <td>0.039862</td>\n",
" <td>0.089133</td>\n",
" <td>27024.455553</td>\n",
" <td>5.977213e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.12</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>90836 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"0 2162635 162635 (2000 SS164) 1.198271 2.679415 \n",
"1 2277475 277475 (2005 WK4) 0.265800 0.594347 \n",
"2 2512244 512244 (2015 YE18) 0.722030 1.614507 \n",
"3 3596030 (2012 BV13) 0.096506 0.215794 \n",
"4 3667127 (2014 GE35) 0.255009 0.570217 \n",
"... ... ... ... ... \n",
"90831 3763337 (2016 VX1) 0.026580 0.059435 \n",
"90832 3837603 (2019 AD3) 0.016771 0.037501 \n",
"90833 54017201 (2020 JP3) 0.031956 0.071456 \n",
"90834 54115824 (2021 CN5) 0.007321 0.016370 \n",
"90835 54205447 (2021 TW7) 0.039862 0.089133 \n",
"\n",
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
"0 13569.249224 5.483974e+07 Earth False \n",
"1 73588.726663 6.143813e+07 Earth False \n",
"2 114258.692129 4.979872e+07 Earth False \n",
"3 24764.303138 2.543497e+07 Earth False \n",
"4 42737.733765 4.627557e+07 Earth False \n",
"... ... ... ... ... \n",
"90831 52078.886692 1.230039e+07 Earth False \n",
"90832 46114.605073 5.432121e+07 Earth False \n",
"90833 7566.807732 2.840077e+07 Earth False \n",
"90834 69199.154484 6.869206e+07 Earth False \n",
"90835 27024.455553 5.977213e+07 Earth False \n",
"\n",
" absolute_magnitude hazardous \n",
"0 16.73 False \n",
"1 20.00 True \n",
"2 17.83 False \n",
"3 22.20 False \n",
"4 20.09 True \n",
"... ... ... \n",
"90831 25.00 False \n",
"90832 26.00 False \n",
"90833 24.60 False \n",
"90834 27.80 False \n",
"90835 24.12 False \n",
"\n",
"[90836 rows x 10 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\".//static//csv//neo.csv\")\n",
"print(df.columns)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Бизнес-цели:\n",
"\n",
"1. Идентификация потенциально опасных объектов\n",
"\n",
"Описание: классифицировать астероиды как потенциально опасные или безопасные (используя целевой признак \"hazardous\"). Эта задача актуальна для оценки рисков и подготовки соответствующих действий по защите Земли.\n",
"\n",
"2. Прогнозирование минимального расстояния до Земли\n",
"\n",
"Описание: предсказать минимальное расстояние до Земли для новых объектов на основе характеристик астероида (скорости, размера и других параметров). Это позволит планировать исследования и наблюдения в зависимости от опасности. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Определение достижимого уровня качества модели для первой задачи "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n",
"\n",
"Целевой признак -- hazardous"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2639</th>\n",
" <td>3634614</td>\n",
" <td>(2013 GT66)</td>\n",
" <td>0.024241</td>\n",
" <td>0.054205</td>\n",
" <td>43303.999094</td>\n",
" <td>4.814117e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29138</th>\n",
" <td>54143560</td>\n",
" <td>(2021 JU1)</td>\n",
" <td>0.030238</td>\n",
" <td>0.067615</td>\n",
" <td>21770.790211</td>\n",
" <td>5.646643e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.72</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36927</th>\n",
" <td>3836085</td>\n",
" <td>(2018 VQ3)</td>\n",
" <td>0.201630</td>\n",
" <td>0.450858</td>\n",
" <td>109358.123029</td>\n",
" <td>6.435051e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.60</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61855</th>\n",
" <td>3769804</td>\n",
" <td>(2017 DJ34)</td>\n",
" <td>0.160160</td>\n",
" <td>0.358129</td>\n",
" <td>78494.609756</td>\n",
" <td>5.595780e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>21.10</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15916</th>\n",
" <td>3824978</td>\n",
" <td>(2018 KS)</td>\n",
" <td>0.006991</td>\n",
" <td>0.015633</td>\n",
" <td>19077.749486</td>\n",
" <td>3.834648e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.90</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29491</th>\n",
" <td>3827304</td>\n",
" <td>(2018 RR1)</td>\n",
" <td>0.002658</td>\n",
" <td>0.005943</td>\n",
" <td>19826.895880</td>\n",
" <td>3.852881e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>30.00</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18373</th>\n",
" <td>3735468</td>\n",
" <td>(2015 WY1)</td>\n",
" <td>0.103408</td>\n",
" <td>0.231228</td>\n",
" <td>82856.544926</td>\n",
" <td>7.314334e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.05</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25031</th>\n",
" <td>3802041</td>\n",
" <td>(2018 FE3)</td>\n",
" <td>0.009651</td>\n",
" <td>0.021579</td>\n",
" <td>34243.774201</td>\n",
" <td>4.257719e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35456</th>\n",
" <td>3430406</td>\n",
" <td>(2008 TR10)</td>\n",
" <td>0.221083</td>\n",
" <td>0.494356</td>\n",
" <td>19557.289783</td>\n",
" <td>2.152970e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.40</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14305</th>\n",
" <td>3285300</td>\n",
" <td>(2005 OG3)</td>\n",
" <td>0.298233</td>\n",
" <td>0.666868</td>\n",
" <td>20309.404706</td>\n",
" <td>1.770015e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>19.75</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>72668 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"2639 3634614 (2013 GT66) 0.024241 0.054205 \n",
"29138 54143560 (2021 JU1) 0.030238 0.067615 \n",
"36927 3836085 (2018 VQ3) 0.201630 0.450858 \n",
"61855 3769804 (2017 DJ34) 0.160160 0.358129 \n",
"15916 3824978 (2018 KS) 0.006991 0.015633 \n",
"... ... ... ... ... \n",
"29491 3827304 (2018 RR1) 0.002658 0.005943 \n",
"18373 3735468 (2015 WY1) 0.103408 0.231228 \n",
"25031 3802041 (2018 FE3) 0.009651 0.021579 \n",
"35456 3430406 (2008 TR10) 0.221083 0.494356 \n",
"14305 3285300 (2005 OG3) 0.298233 0.666868 \n",
"\n",
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
"2639 43303.999094 4.814117e+07 Earth False \n",
"29138 21770.790211 5.646643e+07 Earth False \n",
"36927 109358.123029 6.435051e+07 Earth False \n",
"61855 78494.609756 5.595780e+07 Earth False \n",
"15916 19077.749486 3.834648e+07 Earth False \n",
"... ... ... ... ... \n",
"29491 19826.895880 3.852881e+07 Earth False \n",
"18373 82856.544926 7.314334e+07 Earth False \n",
"25031 34243.774201 4.257719e+07 Earth False \n",
"35456 19557.289783 2.152970e+07 Earth False \n",
"14305 20309.404706 1.770015e+07 Earth False \n",
"\n",
" absolute_magnitude hazardous \n",
"2639 25.20 False \n",
"29138 24.72 False \n",
"36927 20.60 False \n",
"61855 21.10 False \n",
"15916 27.90 False \n",
"... ... ... \n",
"29491 30.00 False \n",
"18373 22.05 False \n",
"25031 27.20 False \n",
"35456 20.40 False \n",
"14305 19.75 False \n",
"\n",
"[72668 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2639</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29138</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36927</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61855</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15916</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29491</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18373</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25031</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35456</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14305</th>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>72668 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" hazardous\n",
"2639 False\n",
"29138 False\n",
"36927 False\n",
"61855 False\n",
"15916 False\n",
"... ...\n",
"29491 False\n",
"18373 False\n",
"25031 False\n",
"35456 False\n",
"14305 False\n",
"\n",
"[72668 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>9040</th>\n",
" <td>2474532</td>\n",
" <td>474532 (2003 VG1)</td>\n",
" <td>0.472667</td>\n",
" <td>1.056915</td>\n",
" <td>21779.237137</td>\n",
" <td>3.443050e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>18.75</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>67305</th>\n",
" <td>3774018</td>\n",
" <td>(2017 HF1)</td>\n",
" <td>0.084053</td>\n",
" <td>0.187949</td>\n",
" <td>53291.016226</td>\n",
" <td>6.862591e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.50</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>77741</th>\n",
" <td>54269585</td>\n",
" <td>(2022 GQ2)</td>\n",
" <td>0.018220</td>\n",
" <td>0.040742</td>\n",
" <td>43089.046433</td>\n",
" <td>2.592726e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.82</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>81520</th>\n",
" <td>54097970</td>\n",
" <td>(2020 XS)</td>\n",
" <td>0.152952</td>\n",
" <td>0.342011</td>\n",
" <td>93246.455599</td>\n",
" <td>4.709054e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>21.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>508</th>\n",
" <td>3730802</td>\n",
" <td>(2015 TT238)</td>\n",
" <td>0.031956</td>\n",
" <td>0.071456</td>\n",
" <td>37708.258544</td>\n",
" <td>4.232149e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.60</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28261</th>\n",
" <td>3532365</td>\n",
" <td>(2010 MH1)</td>\n",
" <td>0.139494</td>\n",
" <td>0.311918</td>\n",
" <td>37604.980238</td>\n",
" <td>7.369507e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>21.40</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1159</th>\n",
" <td>54073345</td>\n",
" <td>(2020 UE)</td>\n",
" <td>0.020728</td>\n",
" <td>0.046349</td>\n",
" <td>36720.077728</td>\n",
" <td>3.366114e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.54</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48095</th>\n",
" <td>3836195</td>\n",
" <td>(2018 VT7)</td>\n",
" <td>0.006991</td>\n",
" <td>0.015633</td>\n",
" <td>7616.496535</td>\n",
" <td>6.376350e+06</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.90</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90234</th>\n",
" <td>3752902</td>\n",
" <td>(2016 JG12)</td>\n",
" <td>0.084053</td>\n",
" <td>0.187949</td>\n",
" <td>21894.554692</td>\n",
" <td>5.736984e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.50</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12013</th>\n",
" <td>3445077</td>\n",
" <td>(2009 BM58)</td>\n",
" <td>0.038420</td>\n",
" <td>0.085909</td>\n",
" <td>49828.611609</td>\n",
" <td>4.305599e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>18168 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"9040 2474532 474532 (2003 VG1) 0.472667 1.056915 \n",
"67305 3774018 (2017 HF1) 0.084053 0.187949 \n",
"77741 54269585 (2022 GQ2) 0.018220 0.040742 \n",
"81520 54097970 (2020 XS) 0.152952 0.342011 \n",
"508 3730802 (2015 TT238) 0.031956 0.071456 \n",
"... ... ... ... ... \n",
"28261 3532365 (2010 MH1) 0.139494 0.311918 \n",
"1159 54073345 (2020 UE) 0.020728 0.046349 \n",
"48095 3836195 (2018 VT7) 0.006991 0.015633 \n",
"90234 3752902 (2016 JG12) 0.084053 0.187949 \n",
"12013 3445077 (2009 BM58) 0.038420 0.085909 \n",
"\n",
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
"9040 21779.237137 3.443050e+07 Earth False \n",
"67305 53291.016226 6.862591e+07 Earth False \n",
"77741 43089.046433 2.592726e+07 Earth False \n",
"81520 93246.455599 4.709054e+07 Earth False \n",
"508 37708.258544 4.232149e+07 Earth False \n",
"... ... ... ... ... \n",
"28261 37604.980238 7.369507e+07 Earth False \n",
"1159 36720.077728 3.366114e+07 Earth False \n",
"48095 7616.496535 6.376350e+06 Earth False \n",
"90234 21894.554692 5.736984e+07 Earth False \n",
"12013 49828.611609 4.305599e+07 Earth False \n",
"\n",
" absolute_magnitude hazardous \n",
"9040 18.75 False \n",
"67305 22.50 False \n",
"77741 25.82 False \n",
"81520 21.20 False \n",
"508 24.60 False \n",
"... ... ... \n",
"28261 21.40 False \n",
"1159 25.54 False \n",
"48095 27.90 False \n",
"90234 22.50 False \n",
"12013 24.20 False \n",
"\n",
"[18168 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>9040</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>67305</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>77741</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>81520</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>508</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28261</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1159</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48095</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90234</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12013</th>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>18168 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" hazardous\n",
"9040 False\n",
"67305 False\n",
"77741 False\n",
"81520 False\n",
"508 False\n",
"... ...\n",
"28261 False\n",
"1159 False\n",
"48095 False\n",
"90234 False\n",
"12013 False\n",
"\n",
"[18168 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Устанавливаем случайное состояние\n",
"random_state = 42\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"hazardous\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование конвейера для классификации данных\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"features_engineering -- трансформер для конструирования признаков\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"class EarthObjectsFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
" def fit(self, X, y=None):\n",
" return self\n",
" def transform(self, X, y=None):\n",
" X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n",
" return X\n",
" def get_feature_names_out(self, features_in):\n",
" return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
" \n",
"\n",
"columns_to_drop = [\"name\", \"orbiting_body\"]\n",
"num_columns = [\"est_diameter_min\", \"est_diameter_max\",\n",
" \"relative_velocity\", \"miss_distance\", \"sentry_object\",\n",
" \"absolute_magnitude\", \"hazardous\"]\n",
"cat_columns = []\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Демонстрация работы конвейера"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" <th>id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2639</th>\n",
" <td>-0.331616</td>\n",
" <td>-0.331616</td>\n",
" <td>-0.188160</td>\n",
" <td>0.494297</td>\n",
" <td>0.0</td>\n",
" <td>0.577785</td>\n",
" <td>-0.328347</td>\n",
" <td>3634614</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29138</th>\n",
" <td>-0.312486</td>\n",
" <td>-0.312486</td>\n",
" <td>-1.040729</td>\n",
" <td>0.866716</td>\n",
" <td>0.0</td>\n",
" <td>0.412170</td>\n",
" <td>-0.328347</td>\n",
" <td>54143560</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36927</th>\n",
" <td>0.234246</td>\n",
" <td>0.234246</td>\n",
" <td>2.427134</td>\n",
" <td>1.219399</td>\n",
" <td>0.0</td>\n",
" <td>-1.009355</td>\n",
" <td>-0.328347</td>\n",
" <td>3836085</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61855</th>\n",
" <td>0.101960</td>\n",
" <td>0.101960</td>\n",
" <td>1.205148</td>\n",
" <td>0.843963</td>\n",
" <td>0.0</td>\n",
" <td>-0.836840</td>\n",
" <td>-0.328347</td>\n",
" <td>3769804</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15916</th>\n",
" <td>-0.386643</td>\n",
" <td>-0.386643</td>\n",
" <td>-1.147355</td>\n",
" <td>0.056145</td>\n",
" <td>0.0</td>\n",
" <td>1.509367</td>\n",
" <td>-0.328347</td>\n",
" <td>3824978</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29491</th>\n",
" <td>-0.400466</td>\n",
" <td>-0.400466</td>\n",
" <td>-1.117694</td>\n",
" <td>0.064301</td>\n",
" <td>0.0</td>\n",
" <td>2.233931</td>\n",
" <td>-0.328347</td>\n",
" <td>3827304</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18373</th>\n",
" <td>-0.079077</td>\n",
" <td>-0.079077</td>\n",
" <td>1.377851</td>\n",
" <td>1.612734</td>\n",
" <td>0.0</td>\n",
" <td>-0.509061</td>\n",
" <td>-0.328347</td>\n",
" <td>3735468</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25031</th>\n",
" <td>-0.378159</td>\n",
" <td>-0.378159</td>\n",
" <td>-0.546884</td>\n",
" <td>0.245400</td>\n",
" <td>0.0</td>\n",
" <td>1.267846</td>\n",
" <td>-0.328347</td>\n",
" <td>3802041</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35456</th>\n",
" <td>0.296300</td>\n",
" <td>0.296300</td>\n",
" <td>-1.128369</td>\n",
" <td>-0.696130</td>\n",
" <td>0.0</td>\n",
" <td>-1.078361</td>\n",
" <td>-0.328347</td>\n",
" <td>3430406</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14305</th>\n",
" <td>0.542404</td>\n",
" <td>0.542404</td>\n",
" <td>-1.098590</td>\n",
" <td>-0.867440</td>\n",
" <td>0.0</td>\n",
" <td>-1.302631</td>\n",
" <td>-0.328347</td>\n",
" <td>3285300</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>72668 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" est_diameter_min est_diameter_max relative_velocity miss_distance \\\n",
"2639 -0.331616 -0.331616 -0.188160 0.494297 \n",
"29138 -0.312486 -0.312486 -1.040729 0.866716 \n",
"36927 0.234246 0.234246 2.427134 1.219399 \n",
"61855 0.101960 0.101960 1.205148 0.843963 \n",
"15916 -0.386643 -0.386643 -1.147355 0.056145 \n",
"... ... ... ... ... \n",
"29491 -0.400466 -0.400466 -1.117694 0.064301 \n",
"18373 -0.079077 -0.079077 1.377851 1.612734 \n",
"25031 -0.378159 -0.378159 -0.546884 0.245400 \n",
"35456 0.296300 0.296300 -1.128369 -0.696130 \n",
"14305 0.542404 0.542404 -1.098590 -0.867440 \n",
"\n",
" sentry_object absolute_magnitude hazardous id \n",
"2639 0.0 0.577785 -0.328347 3634614 \n",
"29138 0.0 0.412170 -0.328347 54143560 \n",
"36927 0.0 -1.009355 -0.328347 3836085 \n",
"61855 0.0 -0.836840 -0.328347 3769804 \n",
"15916 0.0 1.509367 -0.328347 3824978 \n",
"... ... ... ... ... \n",
"29491 0.0 2.233931 -0.328347 3827304 \n",
"18373 0.0 -0.509061 -0.328347 3735468 \n",
"25031 0.0 1.267846 -0.328347 3802041 \n",
"35456 0.0 -1.078361 -0.328347 3430406 \n",
"14305 0.0 -1.302631 -0.328347 3285300 \n",
"\n",
"[72668 rows x 8 columns]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для классификации\n",
" logistic -- логистическая регрессия\n",
"\n",
"ridge -- гребневая регрессия\n",
"\n",
"decision_tree -- дерево решений\n",
"\n",
"knn -- k-ближайших соседей\n",
"\n",
"naive_bayes -- наивный Байесовский классификатор\n",
"\n",
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"\n",
"random_forest -- метод случайного леса (набор деревьев решений)\n",
"\n",
"mlp -- многослойный персептрон (нейронная сеть)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Сводная таблица оценок качества для использованных моделей классификации"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA3YAAAQ9CAYAAAAVld+sAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVhU1f8H8PewDPuwyZoIKIiioIll5EZKgJppmuZS7pql5fJ1LRfQ1LTc18q9NLfSn5kb7mukKC6I5AKCyaICIirbzP39QXNtYmYABQaY9+t57pNzzrn3njsJHz/3nnOuRBAEAURERERERFRtGei6A0RERERERPRymNgRERERERFVc0zsiIiIiIiIqjkmdkRERERERNUcEzsiIiIiIqJqjokdERERERFRNcfEjoiIiIiIqJpjYkdERERERFTNMbEjIiIiIiKq5pjYUY2yfv16SCQSJCYmVsjxExMTIZFIsH79+nI53rFjxyCRSHDs2LFyOR4REVFNER4eDolEUqq2EokE4eHhFdshoiqOiR1RJVixYkW5JYNERERERP9lpOsOEFUn7u7uePbsGYyNjcu034oVK1CrVi0MGDBApbxNmzZ49uwZpFJpOfaSiIio+psyZQomTZqk624QVRtM7IjKQCKRwNTUtNyOZ2BgUK7HIyIiqgmePHkCCwsLGBnxn6pEpcWhmFTjrVixAo0aNYKJiQlcXV0xYsQIZGVlFWu3fPly1K1bF2ZmZnj99ddx8uRJBAUFISgoSGyjbo5damoqBg4ciNq1a8PExAQuLi7o0qWLOM/Pw8MDsbGxOH78OCQSCSQSiXhMTXPsoqKi0LFjR9ja2sLCwgL+/v5YvHhx+X4xREREVYByLt21a9fQp08f2NraolWrVmrn2OXl5WHMmDFwcHCAlZUV3n33Xdy9e1ftcY8dO4bmzZvD1NQU9erVw3fffadx3t5PP/2EgIAAmJmZwc7ODr169UJycnKFXC9RReFtEKrRwsPDERERgeDgYHzyySeIj4/HypUrce7cOZw+fVocUrly5UqMHDkSrVu3xpgxY5CYmIiuXbvC1tYWtWvX1nqO7t27IzY2Fp999hk8PDyQnp6OyMhIJCUlwcPDA4sWLcJnn30GS0tLfPnllwAAJycnjceLjIzEO++8AxcXF4waNQrOzs6Ii4vDnj17MGrUqPL7coiIiKqQHj16wNvbG7Nnz4YgCEhPTy/WZsiQIfjpp5/Qp08fvPnmmzhy5Ag6depUrN3FixcRFhYGFxcXREREQC6XY8aMGXBwcCjWdtasWZg6dSp69uyJIUOG4P79+1i6dCnatGmDixcvwsbGpiIul6j8CUQ1yLp16wQAQkJCgpCeni5IpVIhJCREkMvlYptly5YJAIS1a9cKgiAIeXl5gr29vfDaa68JBQUFYrv169cLAIS2bduKZQkJCQIAYd26dYIgCEJmZqYAQPjmm2+09qtRo0Yqx1E6evSoAEA4evSoIAiCUFhYKHh6egru7u5CZmamSluFQlH6L4KIiKiamD59ugBA6N27t9pypZiYGAGA8Omnn6q069OnjwBAmD59uljWuXNnwdzcXPj777/Fshs3bghGRkYqx0xMTBQMDQ2FWbNmqRzzypUrgpGRUbFyoqqMQzGpxjp06BDy8/MxevRoGBg8/6s+dOhQyGQy/P777wCA8+fP4+HDhxg6dKjKWP6+ffvC1tZW6znMzMwglUpx7NgxZGZmvnSfL168iISEBIwePbrYHcLSLvlMRERUHQ0fPlxr/d69ewEAn3/+uUr56NGjVT7L5XIcOnQIXbt2haurq1ju5eWFDh06qLT99ddfoVAo0LNnTzx48EDcnJ2d4e3tjaNHj77EFRFVLg7FpBrrzp07AAAfHx+VcqlUirp164r1yv96eXmptDMyMoKHh4fWc5iYmGDu3Ln43//+BycnJ7zxxht455130K9fPzg7O5e5z7du3QIANG7cuMz7EhERVWeenp5a6+/cuQMDAwPUq1dPpfy/cT49PR3Pnj0rFteB4rH+xo0bEAQB3t7eas9Z1lWwiXSJiR3RSxo9ejQ6d+6MXbt24cCBA5g6dSrmzJmDI0eO4NVXX9V194iIiKoFMzOzSj+nQqGARCLBvn37YGhoWKze0tKy0vtE9KI4FJNqLHd3dwBAfHy8Snl+fj4SEhLEeuV/b968qdKusLBQXNmyJPXq1cP//vc/HDx4EFevXkV+fj7mz58v1pd2GKXyLuTVq1dL1Z6IiEhfuLu7Q6FQiKNblP4b5x0dHWFqalosrgPFY329evUgCAI8PT0RHBxcbHvjjTfK/0KIKggTO6qxgoODIZVKsWTJEgiCIJavWbMGjx49ElfRat68Oezt7fHDDz+gsLBQbLdp06YS5809ffoUubm5KmX16tWDlZUV8vLyxDILCwu1r1j4r2bNmsHT0xOLFi0q1v7f10BERKRvlPPjlixZolK+aNEilc+GhoYIDg7Grl27cO/ePbH85s2b2Ldvn0rbbt26wdDQEBEREcXirCAIePjwYTleAVHF4lBMqrEcHBwwefJkREREICwsDO+++y7i4+OxYsUKvPbaa/jwww8BFM25Cw8Px2effYZ27dqhZ8+eSExMxPr161GvXj2tT9v++usvtG/fHj179oSvry+MjIywc+dOpKWloVevXmK7gIAArFy5El999RW8vLzg6OiIdu3aFTuegYEBVq5cic6dO6Np06YYOHAgXFxccP36dcTGxuLAgQPl/0URERFVA02bNkXv3r2xYsUKPHr0CG+++SYOHz6s9slceHg4Dh48iJYtW+KTTz6BXC7HsmXL0LhxY8TExIjt6tWrh6+++gqTJ08WX3VkZWWFhIQE7Ny5E8OGDcO4ceMq8SqJXhwTO6rRwsPD4eDggGXLlmHMmDGws7PDsGHDMHv2bJUJ0SNHjoQgCJg/fz7GjRuHJk2aYPfu3fj8889hamqq8fhubm7o3bs3Dh8+jB9//BFGRkZo0KABtm3bhu7du4vtpk2bhjt37mDevHl4/Pgx2rZtqzaxA4DQ0FAcPXoUERERmD9/PhQKBerVq4ehQ4eW3xdDRERUDa1duxYODg7YtGkTdu3ahXbt2uH333+Hm5ubSruAgADs27cP48aNw9SpU+Hm5oYZM2YgLi4O169fV2k7adIk1K9fHwsXLkRERASAovgeEhKCd999t9KujehlSQSO7yJSS6FQwMHBAd26dcMPP/yg6+4QERHRS+ratStiY2Nx48YNXXeFqNxxjh0RgNzc3GJj6zdu3IiMjAwEBQXpplNERET0wp49e6by+caNG9i7dy/jOtVYfGJHBODYsWMYM2YMevToAXt7e1y4cAFr1qxBw4YNER0dDalUqusuEhERURm4uLhgwIAB4rtrV65ciby8PFy8eFHje+uIqjPOsSMC4OHhATc3NyxZsgQZGRmws7NDv3798PXXXzOpIyIiqobCwsLw888/IzU1FSYmJggMDMTs2bOZ1FGNxSd2RERERERE1Rzn2BEREREREVVzTOyIiIiIiIiqOc6xq6YUCgXu3bsHKysrrS/QJqqpBEHA48eP4erqCgOD8r1HlZubi/z8fK1tpFKp1nccEpF+Ynwmfabr2Azod3xmYldN3bt3r9jLOIn0UXJyMmrXrl1ux8vNzYWnuyVS0+Va2zk7OyMhIUFvgwcRqcf4TKS72Azod3xmYldNWVlZAQDuXPCAzJIjanXhvfp+uu6CXitEAU5hr/izUF7y8/ORmi7HzfNukFmp/9nKfqyAV/Nk5Ofn62XgICLNGJ91r+tHvXXdBb1VWJiH09Hf6iQ2A4zPTOyqKeXwDpmlgda/4FRxjCTGuu6CfvtnPd+KGupkaSWBpZX6YyvA4VVEpB7js+4ZGenfP+irGl3EZoDxmYkdEZEaBYIcBRreBlMgKCq5N0RERKQtNhfV63d8ZmJHRKSGAgIUUB88NJUTERFRxdEWm5X1+oyJHRGRGgoIkDOxIyIiqjK0xWZlvT5jYkdEpEaBoECBhvig70M9iIiIdEFbbFbW6zMmdkREaij+2TTVERERUeXSFptRQp0
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"hazardous\", \"safe\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"16400 - количество истинных положительных диагнозов (True Positives), где модель правильно определила объекты как \"hazardous\".\n",
"\n",
"1768 в некоторых моделях - количество ложных отрицательных диагнозов (False Negatives), где модель неправильно определила объекты, которые на самом деле принадлежат к классу \"hazardous\", но были отнесены к классу \"safe\". \n",
"\n",
"Исходя из значений True Positives и False Negatives, можно сказать, что модель имеет высокую точность при предсказании класса \"hazardous\". В принципе, уровень ложных отрицательных результатов в некоторых моделях (1768) говорит нам о том, что существует некотрое небольшое количество примеров, которые модель пропускает.\n",
"\n",
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_371be_row0_col0, #T_371be_row0_col1, #T_371be_row0_col2, #T_371be_row0_col3, #T_371be_row1_col0, #T_371be_row1_col1, #T_371be_row1_col2, #T_371be_row1_col3, #T_371be_row2_col0, #T_371be_row2_col1, #T_371be_row2_col2, #T_371be_row2_col3, #T_371be_row3_col0, #T_371be_row3_col1, #T_371be_row3_col2, #T_371be_row3_col3, #T_371be_row7_col2, #T_371be_row7_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_371be_row0_col4, #T_371be_row0_col5, #T_371be_row0_col6, #T_371be_row0_col7, #T_371be_row1_col4, #T_371be_row1_col5, #T_371be_row1_col6, #T_371be_row1_col7, #T_371be_row2_col4, #T_371be_row2_col5, #T_371be_row2_col6, #T_371be_row2_col7, #T_371be_row3_col4, #T_371be_row3_col5, #T_371be_row3_col6, #T_371be_row3_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row4_col0 {\n",
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
"#T_371be_row4_col1 {\n",
" background-color: #77d153;\n",
" color: #000000;\n",
"}\n",
"#T_371be_row4_col2 {\n",
" background-color: #63cb5f;\n",
" color: #000000;\n",
"}\n",
"#T_371be_row4_col3 {\n",
" background-color: #4ac16d;\n",
" color: #000000;\n",
"}\n",
"#T_371be_row4_col4 {\n",
" background-color: #c03a83;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row4_col5 {\n",
" background-color: #b32c8e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row4_col6 {\n",
" background-color: #c7427c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row4_col7 {\n",
" background-color: #bd3786;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row5_col0, #T_371be_row5_col1, #T_371be_row5_col2, #T_371be_row5_col3, #T_371be_row6_col0, #T_371be_row6_col1, #T_371be_row6_col2, #T_371be_row6_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row5_col4, #T_371be_row6_col4 {\n",
" background-color: #8004a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row5_col5, #T_371be_row6_col5 {\n",
" background-color: #7d03a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row5_col6, #T_371be_row5_col7, #T_371be_row6_col6, #T_371be_row6_col7, #T_371be_row7_col4, #T_371be_row7_col5 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row7_col0 {\n",
" background-color: #25ac82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row7_col1 {\n",
" background-color: #26ad81;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row7_col6 {\n",
" background-color: #ac2694;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_371be_row7_col7 {\n",
" background-color: #ad2793;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_371be\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_371be_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_371be_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_371be_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_371be_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_371be_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_371be_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_371be_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_371be_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_371be_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_371be_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_371be_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_371be_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_371be_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_371be_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_371be_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_371be_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_371be_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_371be_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_371be_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_371be_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_371be_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_371be_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_371be_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_371be_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_371be_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_371be_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_371be_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_371be_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_371be_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_371be_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_371be_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_371be_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_371be_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_371be_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_371be_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_371be_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
" <td id=\"T_371be_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_371be_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_371be_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_371be_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_371be_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" <td id=\"T_371be_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
" <td id=\"T_371be_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
" <td id=\"T_371be_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_371be_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_371be_row4_col0\" class=\"data row4 col0\" >0.884596</td>\n",
" <td id=\"T_371be_row4_col1\" class=\"data row4 col1\" >0.826374</td>\n",
" <td id=\"T_371be_row4_col2\" class=\"data row4 col2\" >0.744627</td>\n",
" <td id=\"T_371be_row4_col3\" class=\"data row4 col3\" >0.638009</td>\n",
" <td id=\"T_371be_row4_col4\" class=\"data row4 col4\" >0.965693</td>\n",
" <td id=\"T_371be_row4_col5\" class=\"data row4 col5\" >0.951728</td>\n",
" <td id=\"T_371be_row4_col6\" class=\"data row4 col6\" >0.808599</td>\n",
" <td id=\"T_371be_row4_col7\" class=\"data row4 col7\" >0.720077</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_371be_level0_row5\" class=\"row_heading level0 row5\" >naive_bayes</th>\n",
" <td id=\"T_371be_row5_col0\" class=\"data row5 col0\" >0.000000</td>\n",
" <td id=\"T_371be_row5_col1\" class=\"data row5 col1\" >0.000000</td>\n",
" <td id=\"T_371be_row5_col2\" class=\"data row5 col2\" >0.000000</td>\n",
" <td id=\"T_371be_row5_col3\" class=\"data row5 col3\" >0.000000</td>\n",
" <td id=\"T_371be_row5_col4\" class=\"data row5 col4\" >0.902681</td>\n",
" <td id=\"T_371be_row5_col5\" class=\"data row5 col5\" >0.902686</td>\n",
" <td id=\"T_371be_row5_col6\" class=\"data row5 col6\" >0.000000</td>\n",
" <td id=\"T_371be_row5_col7\" class=\"data row5 col7\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_371be_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
" <td id=\"T_371be_row6_col0\" class=\"data row6 col0\" >0.000000</td>\n",
" <td id=\"T_371be_row6_col1\" class=\"data row6 col1\" >0.000000</td>\n",
" <td id=\"T_371be_row6_col2\" class=\"data row6 col2\" >0.000000</td>\n",
" <td id=\"T_371be_row6_col3\" class=\"data row6 col3\" >0.000000</td>\n",
" <td id=\"T_371be_row6_col4\" class=\"data row6 col4\" >0.902681</td>\n",
" <td id=\"T_371be_row6_col5\" class=\"data row6 col5\" >0.902686</td>\n",
" <td id=\"T_371be_row6_col6\" class=\"data row6 col6\" >0.000000</td>\n",
" <td id=\"T_371be_row6_col7\" class=\"data row6 col7\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_371be_level0_row7\" class=\"row_heading level0 row7\" >ridge</th>\n",
" <td id=\"T_371be_row7_col0\" class=\"data row7 col0\" >0.415780</td>\n",
" <td id=\"T_371be_row7_col1\" class=\"data row7 col1\" >0.421253</td>\n",
" <td id=\"T_371be_row7_col2\" class=\"data row7 col2\" >1.000000</td>\n",
" <td id=\"T_371be_row7_col3\" class=\"data row7 col3\" >1.000000</td>\n",
" <td id=\"T_371be_row7_col4\" class=\"data row7 col4\" >0.863255</td>\n",
" <td id=\"T_371be_row7_col5\" class=\"data row7 col5\" >0.866303</td>\n",
" <td id=\"T_371be_row7_col6\" class=\"data row7 col6\" >0.587351</td>\n",
" <td id=\"T_371be_row7_col7\" class=\"data row7 col7\" >0.592791</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1dd731d3fe0>"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Все модели в данной выборке, а именно логистическая регрессия, ридж-регрессия, дерево решений, KNN, наивный байесовский классификатор, градиентный бустинг, случайный лес и многослойный перцептрон (MLP) демонстрируют неплохие значения по всем метрикам на обучающих и тестовых наборах данных.\n",
"\n",
"Модели Naive Bayes и MLP не так эффективны по сравнению с другими, но в некоторых метриках показывают высокие результаты. \n",
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_9ba87_row0_col0, #T_9ba87_row0_col1, #T_9ba87_row1_col0, #T_9ba87_row1_col1, #T_9ba87_row2_col0, #T_9ba87_row2_col1, #T_9ba87_row3_col0, #T_9ba87_row3_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_9ba87_row0_col2, #T_9ba87_row0_col3, #T_9ba87_row0_col4, #T_9ba87_row1_col2, #T_9ba87_row1_col3, #T_9ba87_row1_col4, #T_9ba87_row2_col2, #T_9ba87_row2_col3, #T_9ba87_row2_col4, #T_9ba87_row3_col2, #T_9ba87_row3_col3, #T_9ba87_row3_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9ba87_row4_col0, #T_9ba87_row6_col1, #T_9ba87_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9ba87_row4_col1 {\n",
" background-color: #40bd72;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9ba87_row4_col2 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9ba87_row4_col3, #T_9ba87_row6_col2 {\n",
" background-color: #a51f99;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9ba87_row4_col4 {\n",
" background-color: #ae2892;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9ba87_row5_col0 {\n",
" background-color: #4ac16d;\n",
" color: #000000;\n",
"}\n",
"#T_9ba87_row5_col1 {\n",
" background-color: #5cc863;\n",
" color: #000000;\n",
"}\n",
"#T_9ba87_row5_col2 {\n",
" background-color: #d14e72;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9ba87_row5_col3 {\n",
" background-color: #ba3388;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9ba87_row5_col4 {\n",
" background-color: #bb3488;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9ba87_row6_col0, #T_9ba87_row7_col0 {\n",
" background-color: #1e9d89;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9ba87_row6_col3, #T_9ba87_row6_col4, #T_9ba87_row7_col2, #T_9ba87_row7_col3, #T_9ba87_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_9ba87\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_9ba87_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_9ba87_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_9ba87_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_9ba87_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_9ba87_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_9ba87_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_9ba87_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_9ba87_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_9ba87_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_9ba87_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_9ba87_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9ba87_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_9ba87_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_9ba87_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_9ba87_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_9ba87_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_9ba87_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9ba87_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_9ba87_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_9ba87_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_9ba87_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_9ba87_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_9ba87_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9ba87_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
" <td id=\"T_9ba87_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_9ba87_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_9ba87_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_9ba87_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_9ba87_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9ba87_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
" <td id=\"T_9ba87_row4_col0\" class=\"data row4 col0\" >0.866303</td>\n",
" <td id=\"T_9ba87_row4_col1\" class=\"data row4 col1\" >0.592791</td>\n",
" <td id=\"T_9ba87_row4_col2\" class=\"data row4 col2\" >0.995675</td>\n",
" <td id=\"T_9ba87_row4_col3\" class=\"data row4 col3\" >0.528180</td>\n",
" <td id=\"T_9ba87_row4_col4\" class=\"data row4 col4\" >0.599051</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9ba87_level0_row5\" class=\"row_heading level0 row5\" >knn</th>\n",
" <td id=\"T_9ba87_row5_col0\" class=\"data row5 col0\" >0.951728</td>\n",
" <td id=\"T_9ba87_row5_col1\" class=\"data row5 col1\" >0.720077</td>\n",
" <td id=\"T_9ba87_row5_col2\" class=\"data row5 col2\" >0.953405</td>\n",
" <td id=\"T_9ba87_row5_col3\" class=\"data row5 col3\" >0.694141</td>\n",
" <td id=\"T_9ba87_row5_col4\" class=\"data row5 col4\" >0.701100</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9ba87_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_9ba87_row6_col0\" class=\"data row6 col0\" >0.902686</td>\n",
" <td id=\"T_9ba87_row6_col1\" class=\"data row6 col1\" >0.000000</td>\n",
" <td id=\"T_9ba87_row6_col2\" class=\"data row6 col2\" >0.766341</td>\n",
" <td id=\"T_9ba87_row6_col3\" class=\"data row6 col3\" >0.000000</td>\n",
" <td id=\"T_9ba87_row6_col4\" class=\"data row6 col4\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9ba87_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_9ba87_row7_col0\" class=\"data row7 col0\" >0.902686</td>\n",
" <td id=\"T_9ba87_row7_col1\" class=\"data row7 col1\" >0.000000</td>\n",
" <td id=\"T_9ba87_row7_col2\" class=\"data row7 col2\" >0.500000</td>\n",
" <td id=\"T_9ba87_row7_col3\" class=\"data row7 col3\" >0.000000</td>\n",
" <td id=\"T_9ba87_row7_col4\" class=\"data row7 col4\" >0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1dd76e7ec00>"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Схожий вывод можно сделать и для следующих метрик: Accuracy, F1, ROC AUC, Cohen's Kappa и MCC. Все модели, кроме Naive Bayes и MLP, указывают на хорошо-развитую способность к выделению классов"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>Predicted</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [id, Predicted, name, est_diameter_min, est_diameter_max, relative_velocity, miss_distance, orbiting_body, sentry_object, absolute_magnitude, hazardous]\n",
"Index: []"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"hazardous\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Пример использования обученной модели (конвейера) для предсказания\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>67305</th>\n",
" <td>3774018</td>\n",
" <td>(2017 HF1)</td>\n",
" <td>0.084053</td>\n",
" <td>0.187949</td>\n",
" <td>53291.016226</td>\n",
" <td>68625911.198806</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.5</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"67305 3774018 (2017 HF1) 0.084053 0.187949 \n",
"\n",
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
"67305 53291.016226 68625911.198806 Earth False \n",
"\n",
" absolute_magnitude hazardous \n",
"67305 22.5 False "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" <th>id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>67305</th>\n",
" <td>-0.140818</td>\n",
" <td>-0.140818</td>\n",
" <td>0.207258</td>\n",
" <td>1.410653</td>\n",
" <td>0.0</td>\n",
" <td>-0.353797</td>\n",
" <td>-0.328347</td>\n",
" <td>3774018.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" est_diameter_min est_diameter_max relative_velocity miss_distance \\\n",
"67305 -0.140818 -0.140818 0.207258 1.410653 \n",
"\n",
" sentry_object absolute_magnitude hazardous id \n",
"67305 0.0 -0.353797 -0.328347 3774018.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: False (proba: [9.99855425e-01 1.44575476e-04])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 0'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 67305\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Подбор гиперпараметров методом поиска по сетке "
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 5,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 50}"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"import pandas as pd\n",
"\n",
"\n",
"# Определяем числовые признаки\n",
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
"\n",
"# Установка random_state\n",
"random_state = 42\n",
"\n",
"# Определение трансформера\n",
"pipeline_end = ColumnTransformer([\n",
" ('numeric', StandardScaler(), numeric_features),\n",
" # Добавьте другие трансформеры, если требуется\n",
"])\n",
"\n",
"# Объявление модели\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=50,\n",
")\n",
"\n",
"# Создание пайплайна с корректными шагами\n",
"result = {}\n",
"\n",
"# Обучение модели\n",
"result[\"pipeline\"] = Pipeline([\n",
" (\"pipeline\", pipeline_end),\n",
" (\"model\", optimized_model)\n",
"]).fit(X_train, y_train.values.ravel())\n",
"\n",
"# Прогнозирование и расчет метрик\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"# Метрики для оценки модели\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_103c9_row0_col0, #T_103c9_row0_col1, #T_103c9_row0_col2, #T_103c9_row0_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_103c9_row0_col4, #T_103c9_row0_col5, #T_103c9_row0_col6, #T_103c9_row0_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_103c9_row1_col0, #T_103c9_row1_col1, #T_103c9_row1_col2, #T_103c9_row1_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_103c9_row1_col4, #T_103c9_row1_col5, #T_103c9_row1_col6, #T_103c9_row1_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_103c9\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_103c9_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_103c9_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_103c9_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_103c9_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_103c9_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_103c9_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_103c9_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_103c9_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_103c9_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_103c9_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_103c9_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_103c9_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_103c9_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_103c9_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_103c9_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_103c9_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_103c9_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_103c9_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_103c9_row1_col0\" class=\"data row1 col0\" >0.833191</td>\n",
" <td id=\"T_103c9_row1_col1\" class=\"data row1 col1\" >0.862500</td>\n",
" <td id=\"T_103c9_row1_col2\" class=\"data row1 col2\" >0.138433</td>\n",
" <td id=\"T_103c9_row1_col3\" class=\"data row1 col3\" >0.156109</td>\n",
" <td id=\"T_103c9_row1_col4\" class=\"data row1 col4\" >0.913456</td>\n",
" <td id=\"T_103c9_row1_col5\" class=\"data row1 col5\" >0.915456</td>\n",
" <td id=\"T_103c9_row1_col6\" class=\"data row1 col6\" >0.237420</td>\n",
" <td id=\"T_103c9_row1_col7\" class=\"data row1 col7\" >0.264368</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1dd76b55010>"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_6af3a_row0_col0, #T_6af3a_row0_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_6af3a_row0_col2, #T_6af3a_row0_col3, #T_6af3a_row0_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6af3a_row1_col0, #T_6af3a_row1_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6af3a_row1_col2, #T_6af3a_row1_col3, #T_6af3a_row1_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_6af3a\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_6af3a_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_6af3a_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_6af3a_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_6af3a_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_6af3a_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_6af3a_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_6af3a_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_6af3a_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_6af3a_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_6af3a_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_6af3a_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6af3a_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_6af3a_row1_col0\" class=\"data row1 col0\" >0.915456</td>\n",
" <td id=\"T_6af3a_row1_col1\" class=\"data row1 col1\" >0.264368</td>\n",
" <td id=\"T_6af3a_row1_col2\" class=\"data row1 col2\" >0.927493</td>\n",
" <td id=\"T_6af3a_row1_col3\" class=\"data row1 col3\" >0.241751</td>\n",
" <td id=\"T_6af3a_row1_col4\" class=\"data row1 col4\" >0.345694</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1dd76b54c50>"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA5cAAAGsCAYAAABJt1OiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABwNElEQVR4nO3deVxU5eLH8e8AsogOuAGSqJi7oqZ2jRZXEpdK0+xaXpfcbiWVWlrd1NzKstzTrCzNm/601dTMJM2l3BLDPSs3SAUzBERDlpnfH1xGJ1EZ5+jA8Hm/Xuf1a855zjPPmZ/Xr895nvMck9VqtQoAAAAAACd4uLoBAAAAAIDij84lAAAAAMBpdC4BAAAAAE6jcwkAAAAAcBqdSwAAAACA0+hcAgAAAACcRucSAAAAAOA0OpcAAAAAAKd5uboBAIDiJTMzU1lZWYbV5+3tLV9fX8PqAwDAEeSacehcAgAKLTMzU+HVyijpVK5hdYaEhOjIkSMlNogBAK5DrhmLziUAoNCysrKUdCpXR+KqyVzW+Scr0s9aFN7smLKyskpkCAMAXItcMxadSwCAw8xlPQwJYQAAigJyzRh0LgEADsu1WpRrNaYeAABcjVwzBt1zAIDDLLIatgEA4GquyrWNGzfq/vvvV2hoqEwmk5YtW3ZZmQMHDuiBBx5QQECA/P39dfvttyshIcF2PDMzU0OGDFGFChVUpkwZde/eXcnJyXZ1JCQkqHPnzipdurSCgoI0YsQI5eTk2JVZv369mjZtKh8fH9WsWVMLFixw6FokOpcAgGLE3UIYAFCynTt3To0bN9bs2bMLPH7o0CHdfffdqlu3rtavX6/du3dr9OjRds9zDhs2TCtWrNAnn3yiDRs26MSJE+rWrZvteG5urjp37qysrCxt3rxZH374oRYsWKAxY8bYyhw5ckSdO3dWmzZtFB8fr6FDh2rgwIH65ptvHLoepsUCABxmkUVGTPxxtJb8EO7fv79dcObLD+EBAwZo3LhxMpvN2rdv32Uh/NVXX+mTTz5RQECAYmJi1K1bN/3www+SLoZwSEiINm/erJMnT6pPnz4qVaqUXn31VUkXQ/jxxx/XokWLtHbtWg0cOFCVK1dWdHS0E78IAMAVXJVrHTt2VMeOHa94/KWXXlKnTp00efJk275bb73V9t9paWl6//33tXjxYrVt21aSNH/+fNWrV09bt27VHXfcoTVr1mj//v369ttvFRwcrCZNmmjChAl6/vnnNXbsWHl7e2vu3LkKDw/XlClTJEn16tXT999/r2nTpjmUa4xcAgAclmu1GrY5omPHjpo4caIefPDBAo9fGsK33Xabbr31Vj3wwAMKCgqSdDGEp06dqrZt26pZs2aaP3++Nm/erK1bt0qSLYQ/+ugjNWnSRB07dtSECRM0e/Zs23vQLg3hevXqKSYmRg899JCmTZvmxK8KAHAVo3MtPT3dbrtw4YLDbbJYLPrqq69Uu3ZtRUdHKygoSC1atLCbtRMXF6fs7GxFRUXZ9tWtW1dVq1bVli1bJElbtmxRRESEgoODbWWio6OVnp6uffv22cpcWkd+mfw6CovOJQDA5UpqCAMA3FNYWJgCAgJs26RJkxyu49SpU8rIyNBrr72mDh06aM2aNXrwwQfVrVs3bdiwQZKUlJQkb29vBQYG2p0bHByspKQkW5lLMy3/eP6xq5VJT0/XX3/9Veg207kEADjM6IUPSmoIAwCKBqNzLTExUWlpabbtxRdfdLxNlrwptl26dNGwYcPUpEkTvfDCC7rvvvs0d+5cQ6/fKDxzCQBwmEVW5Rqw0uulIWw2m237fXx8HK/rbyEsSU2aNNHmzZs1d+5ctWrVyun2AgDck9G5Zjab7XLtelSsWFFeXl6qX7++3f785yElKSQkRFlZWUpNTbW7cZqcnKyQkBBbme3bt9vVkb+Q3aVl/r64XXJyssxms/z8/ArdZkYuAQAulx/C+dv1dC6vFsL5q8VeGsKX+nsIFxSw+ceuVsbREAYA4Eq8vb11++236+DBg3b7f/nlF1WrVk2S1KxZM5UqVUpr1661HT948KASEhIUGRkpSYqMjNSePXt06tQpW5nY2FiZzWZbZkZGRtrVkV8mv47ConMJAHBYUXzPZXEMYQBA0eCqXMvIyFB8fLzi4+Ml5a1GHh8fb7spOmLECC1dulTvvfeefvvtN7311ltasWKFnnzySUlSQECABgwYoOHDh+u7775TXFycHnvsMUVGRuqOO+6QJLVv317169dX7969tWvXLn3zzTcaNWqUhgwZYruZ+/jjj+vw4cMaOXKkfv75Z82ZM0cff/yxbSZQYTEtFgBQbGRkZOi3336zfc4P4fLly6tq1aoaMWKE/vnPf6ply5Zq06aNVq9erRUrVmj9+vWS7EO4fPnyMpvNeuqpp64YwpMnT1ZSUlKBIfzWW29p5MiR6t+/v9atW6ePP/5YX3311U3/TQAAxdeOHTvUpk0b2+fhw4dLkvr27asFCxbowQcf1Ny5czVp0iQ9/fTTqlOnjj777DPdfffdtnOmTZsmDw8Pde/eXRcuXFB0dLTmzJljO+7p6amVK1fqiSeeUGRkpPz9/dW3b1+NHz/eViY8PFxfffWVhg0bphkzZqhKlSqaN2+ew6/XMlmtDq4DDwAosdLT0xUQEKBfDgSrbFnnJ7+cPWtR7XrJSktLK9SzKevXr7cL4Xz5ISxJH3zwgSZNmqTff/9dderU0bhx49SlSxdb2czMTD377LP6v//7P7sQzp/yKknHjh3TE088ofXr19tC+LXXXpOX18V7suvXr9ewYcO0f/9+ValSRaNHj1a/fv2u/8cAANx0rs41d0PnEgBQaPkh/LOBIVy3BIcwAMC1yDVj8cwlAAAAAMBpPHMJAHBYrkFLthtRBwAAziLXjEHnEgDgsFxr3mZEPQAAuBq5ZgymxQIAAAAAnMbIJQDAYZb/bUbUAwCAq5FrxqBzCQBwmEUm5cpkSD0AALgauWYMpsUCAAAAAJzGyCUAwGEWa95mRD0AALgauWYMRi4BAAAAAE5j5BIA4LBcg55NMaIOAACcRa4Zg84lAMBhhDAAwJ2Qa8ZgWiwAAAAAwGmMXAIAHGaxmmSxGrBkuwF1AADgLHLNGHQuAQAOY/oQAMCdkGvGYFosAAAAAMBpjFwCAByWKw/lGnB/MteAtgAA4CxyzRh0LgEADrMa9GyKtYQ/mwIAKBrINWMwLRYAAAAA4DRGLgEADmPhAwCAOyHXjEHnEgDgsFyrh3KtBjybYjWgMQAAOIlcMwbTYgEAAAAATmPkEgDgMItMshhwf9KiEn6LFwBQJJBrxmDkEgAAAADgNEYuAQAOY+EDAIA7IdeMQecSAOAw4xY+KNnThwAARQO5ZgymxQIAAAAAnMbIJQDAYXkLHzg/9ceIOgAAcBa5Zgw6lwAAh1nkoVxW1QMAuAlyzRhMiwUAAAAAOI2RSwCAw1j4AADgTsg1Y9C5BAA4zCIPXjYNAHAb5JoxmBYLAAAAAHAaI5cAAIflWk3KtRrwsmkD6gAAwFnkmjEYuQQAAAAAOI2RSwCAw3INWrI9t4Q/mwIAKBrINWPQuQQAOMxi9ZDFgFX1LCV8VT0AQNFArhmDabEAAAAAAKfRuQQAOCx/+pARGwAAruaqXNu4caPuv/9+hYaGymQyadmyZVcs+/jjj8tkMmn69Ol2+1NSUtSrVy+ZzWYFBgZqwIABysjIsCuze/du3XPPPfL19VVYWJgmT558Wf2ffPKJ6tatK19fX0VERGjVqlUOXYtE5xIAcB0suriynjObxcHvdbcQBgAUDa7KtXPnzqlx48aaPXv2Vct98cUX2rp1q0JDQy871qtXL+3bt0+xsbFauXKlNm7cqMGDB9uOp6enq3379qpWrZri4uL0xhtvaOzYsXr33XdtZTZv3qxHHnlEAwYM0E8//aSuXbuqa9eu2rt3r0PXQ+cSAFBsuFsIAwBKto4dO2rixIl68MEHr1jm+PHjeuqpp7Ro0SKVKlXK7tiBAwe0evVqzZs3Ty1atNDdd9+tWbNmacmSJTpx4oQkadGiRcrKytIHH3y
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"hazardous\", \"safe\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В желтых квадрате мы наблюдаем значение 16400, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"hazardsous\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
"\n",
"В фиолетвом квадрате значение 276 указывает на количество правильно классифицированных объектов, отнесенных к классу \"More\". Это является показателем не такой высокой точности модели в определении объектов данного класса."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['id', 'name', 'est_diameter_min', 'est_diameter_max',\n",
" 'relative_velocity', 'miss_distance', 'orbiting_body', 'sentry_object',\n",
" 'absolute_magnitude', 'hazardous'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2162635</td>\n",
" <td>162635 (2000 SS164)</td>\n",
" <td>1.198271</td>\n",
" <td>2.679415</td>\n",
" <td>13569.249224</td>\n",
" <td>5.483974e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>16.73</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2277475</td>\n",
" <td>277475 (2005 WK4)</td>\n",
" <td>0.265800</td>\n",
" <td>0.594347</td>\n",
" <td>73588.726663</td>\n",
" <td>6.143813e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.00</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2512244</td>\n",
" <td>512244 (2015 YE18)</td>\n",
" <td>0.722030</td>\n",
" <td>1.614507</td>\n",
" <td>114258.692129</td>\n",
" <td>4.979872e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>17.83</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3596030</td>\n",
" <td>(2012 BV13)</td>\n",
" <td>0.096506</td>\n",
" <td>0.215794</td>\n",
" <td>24764.303138</td>\n",
" <td>2.543497e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3667127</td>\n",
" <td>(2014 GE35)</td>\n",
" <td>0.255009</td>\n",
" <td>0.570217</td>\n",
" <td>42737.733765</td>\n",
" <td>4.627557e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.09</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90831</th>\n",
" <td>3763337</td>\n",
" <td>(2016 VX1)</td>\n",
" <td>0.026580</td>\n",
" <td>0.059435</td>\n",
" <td>52078.886692</td>\n",
" <td>1.230039e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.00</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90832</th>\n",
" <td>3837603</td>\n",
" <td>(2019 AD3)</td>\n",
" <td>0.016771</td>\n",
" <td>0.037501</td>\n",
" <td>46114.605073</td>\n",
" <td>5.432121e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>26.00</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90833</th>\n",
" <td>54017201</td>\n",
" <td>(2020 JP3)</td>\n",
" <td>0.031956</td>\n",
" <td>0.071456</td>\n",
" <td>7566.807732</td>\n",
" <td>2.840077e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.60</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90834</th>\n",
" <td>54115824</td>\n",
" <td>(2021 CN5)</td>\n",
" <td>0.007321</td>\n",
" <td>0.016370</td>\n",
" <td>69199.154484</td>\n",
" <td>6.869206e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.80</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90835</th>\n",
" <td>54205447</td>\n",
" <td>(2021 TW7)</td>\n",
" <td>0.039862</td>\n",
" <td>0.089133</td>\n",
" <td>27024.455553</td>\n",
" <td>5.977213e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.12</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>90836 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"0 2162635 162635 (2000 SS164) 1.198271 2.679415 \n",
"1 2277475 277475 (2005 WK4) 0.265800 0.594347 \n",
"2 2512244 512244 (2015 YE18) 0.722030 1.614507 \n",
"3 3596030 (2012 BV13) 0.096506 0.215794 \n",
"4 3667127 (2014 GE35) 0.255009 0.570217 \n",
"... ... ... ... ... \n",
"90831 3763337 (2016 VX1) 0.026580 0.059435 \n",
"90832 3837603 (2019 AD3) 0.016771 0.037501 \n",
"90833 54017201 (2020 JP3) 0.031956 0.071456 \n",
"90834 54115824 (2021 CN5) 0.007321 0.016370 \n",
"90835 54205447 (2021 TW7) 0.039862 0.089133 \n",
"\n",
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
"0 13569.249224 5.483974e+07 Earth False \n",
"1 73588.726663 6.143813e+07 Earth False \n",
"2 114258.692129 4.979872e+07 Earth False \n",
"3 24764.303138 2.543497e+07 Earth False \n",
"4 42737.733765 4.627557e+07 Earth False \n",
"... ... ... ... ... \n",
"90831 52078.886692 1.230039e+07 Earth False \n",
"90832 46114.605073 5.432121e+07 Earth False \n",
"90833 7566.807732 2.840077e+07 Earth False \n",
"90834 69199.154484 6.869206e+07 Earth False \n",
"90835 27024.455553 5.977213e+07 Earth False \n",
"\n",
" absolute_magnitude hazardous \n",
"0 16.73 False \n",
"1 20.00 True \n",
"2 17.83 False \n",
"3 22.20 False \n",
"4 20.09 True \n",
"... ... ... \n",
"90831 25.00 False \n",
"90832 26.00 False \n",
"90833 24.60 False \n",
"90834 27.80 False \n",
"90835 24.12 False \n",
"\n",
"[90836 rows x 10 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"random_state=42\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\".//static//csv//neo.csv\")\n",
"print(df.columns)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>35538</th>\n",
" <td>3826685</td>\n",
" <td>(2018 PR10)</td>\n",
" <td>0.038420</td>\n",
" <td>0.085909</td>\n",
" <td>91103.489666</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40393</th>\n",
" <td>2277830</td>\n",
" <td>277830 (2006 HR29)</td>\n",
" <td>0.192555</td>\n",
" <td>0.430566</td>\n",
" <td>28359.611312</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.70</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>58540</th>\n",
" <td>3638201</td>\n",
" <td>(2013 HT25)</td>\n",
" <td>0.004619</td>\n",
" <td>0.010329</td>\n",
" <td>107351.426865</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>28.80</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61670</th>\n",
" <td>3836282</td>\n",
" <td>(2018 WR)</td>\n",
" <td>0.015295</td>\n",
" <td>0.034201</td>\n",
" <td>21423.536884</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>26.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11435</th>\n",
" <td>3802002</td>\n",
" <td>(2018 FU1)</td>\n",
" <td>0.011603</td>\n",
" <td>0.025944</td>\n",
" <td>69856.053840</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>26.80</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6265</th>\n",
" <td>2530151</td>\n",
" <td>530151 (2011 AW55)</td>\n",
" <td>0.211132</td>\n",
" <td>0.472106</td>\n",
" <td>88209.754856</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.50</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54886</th>\n",
" <td>3831736</td>\n",
" <td>(2018 TD5)</td>\n",
" <td>0.035039</td>\n",
" <td>0.078350</td>\n",
" <td>58758.452153</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.40</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>76820</th>\n",
" <td>2512234</td>\n",
" <td>512234 (2015 VO66)</td>\n",
" <td>0.211132</td>\n",
" <td>0.472106</td>\n",
" <td>52355.509176</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.50</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>54054466</td>\n",
" <td>(2020 SG1)</td>\n",
" <td>0.282199</td>\n",
" <td>0.631015</td>\n",
" <td>50527.379563</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>19.87</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15795</th>\n",
" <td>3773929</td>\n",
" <td>(2017 GL7)</td>\n",
" <td>0.075258</td>\n",
" <td>0.168283</td>\n",
" <td>22527.647871</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.74</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>72668 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"35538 3826685 (2018 PR10) 0.038420 0.085909 \n",
"40393 2277830 277830 (2006 HR29) 0.192555 0.430566 \n",
"58540 3638201 (2013 HT25) 0.004619 0.010329 \n",
"61670 3836282 (2018 WR) 0.015295 0.034201 \n",
"11435 3802002 (2018 FU1) 0.011603 0.025944 \n",
"... ... ... ... ... \n",
"6265 2530151 530151 (2011 AW55) 0.211132 0.472106 \n",
"54886 3831736 (2018 TD5) 0.035039 0.078350 \n",
"76820 2512234 512234 (2015 VO66) 0.211132 0.472106 \n",
"860 54054466 (2020 SG1) 0.282199 0.631015 \n",
"15795 3773929 (2017 GL7) 0.075258 0.168283 \n",
"\n",
" relative_velocity orbiting_body sentry_object absolute_magnitude \\\n",
"35538 91103.489666 Earth False 24.20 \n",
"40393 28359.611312 Earth False 20.70 \n",
"58540 107351.426865 Earth False 28.80 \n",
"61670 21423.536884 Earth False 26.20 \n",
"11435 69856.053840 Earth False 26.80 \n",
"... ... ... ... ... \n",
"6265 88209.754856 Earth False 20.50 \n",
"54886 58758.452153 Earth False 24.40 \n",
"76820 52355.509176 Earth False 20.50 \n",
"860 50527.379563 Earth False 19.87 \n",
"15795 22527.647871 Earth False 22.74 \n",
"\n",
" hazardous \n",
"35538 False \n",
"40393 False \n",
"58540 False \n",
"61670 False \n",
"11435 False \n",
"... ... \n",
"6265 False \n",
"54886 False \n",
"76820 True \n",
"860 False \n",
"15795 False \n",
"\n",
"[72668 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>miss_distance</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>35538</th>\n",
" <td>6.350550e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40393</th>\n",
" <td>2.868167e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>58540</th>\n",
" <td>5.388098e+04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61670</th>\n",
" <td>5.103884e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11435</th>\n",
" <td>7.360836e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6265</th>\n",
" <td>4.034289e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54886</th>\n",
" <td>4.389994e+06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>76820</th>\n",
" <td>4.380532e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>5.837007e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15795</th>\n",
" <td>2.281469e+07</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>72668 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" miss_distance\n",
"35538 6.350550e+07\n",
"40393 2.868167e+07\n",
"58540 5.388098e+04\n",
"61670 5.103884e+07\n",
"11435 7.360836e+07\n",
"... ...\n",
"6265 4.034289e+07\n",
"54886 4.389994e+06\n",
"76820 4.380532e+07\n",
"860 5.837007e+07\n",
"15795 2.281469e+07\n",
"\n",
"[72668 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>20406</th>\n",
" <td>3943344</td>\n",
" <td>(2019 YT1)</td>\n",
" <td>0.024241</td>\n",
" <td>0.054205</td>\n",
" <td>22148.962596</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74443</th>\n",
" <td>3879239</td>\n",
" <td>(2019 US)</td>\n",
" <td>0.012722</td>\n",
" <td>0.028447</td>\n",
" <td>26477.211836</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>26.60</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74306</th>\n",
" <td>3879244</td>\n",
" <td>(2019 UU)</td>\n",
" <td>0.013322</td>\n",
" <td>0.029788</td>\n",
" <td>33770.201397</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>26.50</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>45943</th>\n",
" <td>2481965</td>\n",
" <td>481965 (2009 EB1)</td>\n",
" <td>0.193444</td>\n",
" <td>0.432554</td>\n",
" <td>43599.575296</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.69</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>62859</th>\n",
" <td>3789471</td>\n",
" <td>(2017 WJ1)</td>\n",
" <td>0.044112</td>\n",
" <td>0.098637</td>\n",
" <td>36398.080883</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>23.90</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>51634</th>\n",
" <td>3694131</td>\n",
" <td>(2014 UF56)</td>\n",
" <td>0.008801</td>\n",
" <td>0.019681</td>\n",
" <td>57414.305699</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.40</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>85083</th>\n",
" <td>54235475</td>\n",
" <td>(2022 AG1)</td>\n",
" <td>0.024920</td>\n",
" <td>0.055724</td>\n",
" <td>50882.935767</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.14</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38905</th>\n",
" <td>3775176</td>\n",
" <td>(2017 LD)</td>\n",
" <td>0.008405</td>\n",
" <td>0.018795</td>\n",
" <td>24954.754212</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.50</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16144</th>\n",
" <td>2434734</td>\n",
" <td>434734 (2006 FX)</td>\n",
" <td>0.265800</td>\n",
" <td>0.594347</td>\n",
" <td>57455.404666</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.00</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54508</th>\n",
" <td>3170208</td>\n",
" <td>(2003 YG136)</td>\n",
" <td>0.023150</td>\n",
" <td>0.051765</td>\n",
" <td>72602.093427</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.30</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>18168 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"20406 3943344 (2019 YT1) 0.024241 0.054205 \n",
"74443 3879239 (2019 US) 0.012722 0.028447 \n",
"74306 3879244 (2019 UU) 0.013322 0.029788 \n",
"45943 2481965 481965 (2009 EB1) 0.193444 0.432554 \n",
"62859 3789471 (2017 WJ1) 0.044112 0.098637 \n",
"... ... ... ... ... \n",
"51634 3694131 (2014 UF56) 0.008801 0.019681 \n",
"85083 54235475 (2022 AG1) 0.024920 0.055724 \n",
"38905 3775176 (2017 LD) 0.008405 0.018795 \n",
"16144 2434734 434734 (2006 FX) 0.265800 0.594347 \n",
"54508 3170208 (2003 YG136) 0.023150 0.051765 \n",
"\n",
" relative_velocity orbiting_body sentry_object absolute_magnitude \\\n",
"20406 22148.962596 Earth False 25.20 \n",
"74443 26477.211836 Earth False 26.60 \n",
"74306 33770.201397 Earth False 26.50 \n",
"45943 43599.575296 Earth False 20.69 \n",
"62859 36398.080883 Earth False 23.90 \n",
"... ... ... ... ... \n",
"51634 57414.305699 Earth False 27.40 \n",
"85083 50882.935767 Earth False 25.14 \n",
"38905 24954.754212 Earth False 27.50 \n",
"16144 57455.404666 Earth False 20.00 \n",
"54508 72602.093427 Earth False 25.30 \n",
"\n",
" hazardous \n",
"20406 False \n",
"74443 False \n",
"74306 False \n",
"45943 False \n",
"62859 False \n",
"... ... \n",
"51634 False \n",
"85083 False \n",
"38905 False \n",
"16144 True \n",
"54508 False \n",
"\n",
"[18168 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>miss_distance</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>20406</th>\n",
" <td>5.028574e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74443</th>\n",
" <td>1.683201e+06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74306</th>\n",
" <td>3.943220e+06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>45943</th>\n",
" <td>7.346837e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>62859</th>\n",
" <td>6.352916e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>51634</th>\n",
" <td>1.987273e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>85083</th>\n",
" <td>3.119646e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38905</th>\n",
" <td>1.111942e+07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16144</th>\n",
" <td>8.501684e+06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54508</th>\n",
" <td>4.624727e+07</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>18168 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" miss_distance\n",
"20406 5.028574e+07\n",
"74443 1.683201e+06\n",
"74306 3.943220e+06\n",
"45943 7.346837e+07\n",
"62859 6.352916e+07\n",
"... ...\n",
"51634 1.987273e+07\n",
"85083 3.119646e+07\n",
"38905 1.111942e+07\n",
"16144 8.501684e+06\n",
"54508 4.624727e+07\n",
"\n",
"[18168 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" target_colname: str = \"miss_distance\",\n",
" frac_train: float = 0.8,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" # Проверка наличия целевого признака\n",
" if target_colname not in df_input.columns:\n",
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
" \n",
" # Разделяем данные на признаки и целевую переменную\n",
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
" y = df_input[[target_colname]] # Целевая переменная\n",
"\n",
" # Разделяем данные на обучающую и тестовую выборки\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" \n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"# Применение функции для разделения данных\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
" df, \n",
" target_colname=\"miss_distance\", \n",
" frac_train=0.8, \n",
" random_state=42\n",
")\n",
"\n",
"# Для отображения результатов\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование конвейера для решения задачи регрессии"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.ensemble import RandomForestRegressor # Пример регрессионной модели\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"\n",
"class EarthObjectsFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
" \n",
" def fit(self, X, y=None):\n",
" return self\n",
"\n",
" def transform(self, X, y=None):\n",
" X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n",
" return X\n",
"\n",
" def get_feature_names_out(self, features_in):\n",
" return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
"\n",
"# Указываем столбцы, которые нужно удалить и обрабатывать\n",
"columns_to_drop = [\"name\", \"orbiting_body\"]\n",
"num_columns = [\"est_diameter_min\", \"est_diameter_max\",\n",
" \"relative_velocity\", \"sentry_object\",\n",
" \"absolute_magnitude\", \"hazardous\"]\n",
"cat_columns = [] \n",
"\n",
"# Определяем предобработку для численных данных\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Определяем предобработку для категориальных данных\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Подготовка признаков с использованием ColumnTransformer\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Удаление нежелательных столбцов\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Постобработка признаков\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_cat\", preprocessing_cat, [\"Cabin_type\"]), \n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Создание окончательного конвейера\n",
"pipeline = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" (\"model\", RandomForestRegressor()) # Выбор модели для обучения\n",
" ]\n",
")\n",
"\n",
"# Использование конвейера\n",
"def train_pipeline(X, y):\n",
" pipeline.fit(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для регрессии"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"random_state = 9\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n"
]
},
{
"ename": "ValueError",
"evalue": "could not convert string to float: '(2018 PR10)'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[9], line 8\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m model_name \u001b[38;5;129;01min\u001b[39;00m models\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 8\u001b[0m fitted_model \u001b[38;5;241m=\u001b[39m \u001b[43mmodels\u001b[49m\u001b[43m[\u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mX_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mravel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 11\u001b[0m y_train_pred \u001b[38;5;241m=\u001b[39m fitted_model\u001b[38;5;241m.\u001b[39mpredict(X_train\u001b[38;5;241m.\u001b[39mvalues)\n\u001b[0;32m 12\u001b[0m y_test_pred \u001b[38;5;241m=\u001b[39m fitted_model\u001b[38;5;241m.\u001b[39mpredict(X_test\u001b[38;5;241m.\u001b[39mvalues)\n",
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[1;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1466\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[0;32m 1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 1469\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 1470\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 1471\u001b[0m )\n\u001b[0;32m 1472\u001b[0m ):\n\u001b[1;32m-> 1473\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\linear_model\\_base.py:609\u001b[0m, in \u001b[0;36mLinearRegression.fit\u001b[1;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[0;32m 605\u001b[0m n_jobs_ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_jobs\n\u001b[0;32m 607\u001b[0m accept_sparse \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpositive \u001b[38;5;28;01melse\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcsr\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcsc\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcoo\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m--> 609\u001b[0m X, y \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_data\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 610\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 611\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 612\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 613\u001b[0m \u001b[43m \u001b[49m\u001b[43my_numeric\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 614\u001b[0m \u001b[43m \u001b[49m\u001b[43mmulti_output\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 615\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_writeable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 616\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 618\u001b[0m has_sw \u001b[38;5;241m=\u001b[39m sample_weight \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 619\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_sw:\n",
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:650\u001b[0m, in \u001b[0;36mBaseEstimator._validate_data\u001b[1;34m(self, X, y, reset, validate_separately, cast_to_ndarray, **check_params)\u001b[0m\n\u001b[0;32m 648\u001b[0m y \u001b[38;5;241m=\u001b[39m check_array(y, input_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mcheck_y_params)\n\u001b[0;32m 649\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 650\u001b[0m X, y \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_X_y\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcheck_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 651\u001b[0m out \u001b[38;5;241m=\u001b[39m X, y\n\u001b[0;32m 653\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m check_params\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mensure_2d\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m):\n",
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1301\u001b[0m, in \u001b[0;36mcheck_X_y\u001b[1;34m(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, estimator)\u001b[0m\n\u001b[0;32m 1296\u001b[0m estimator_name \u001b[38;5;241m=\u001b[39m _check_estimator_name(estimator)\n\u001b[0;32m 1297\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 1298\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mestimator_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m requires y to be passed, but the target y is None\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1299\u001b[0m )\n\u001b[1;32m-> 1301\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_array\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1302\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1303\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1304\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_large_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_large_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1305\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1306\u001b[0m \u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1307\u001b[0m \u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1308\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_writeable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_writeable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1309\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_all_finite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_all_finite\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1310\u001b[0m \u001b[43m \u001b[49m\u001b[43mensure_2d\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_2d\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1311\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_nd\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_nd\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1312\u001b[0m \u001b[43m \u001b[49m\u001b[43mensure_min_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_min_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1313\u001b[0m \u001b[43m \u001b[49m\u001b[43mensure_min_features\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_min_features\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1314\u001b[0m \u001b[43m \u001b[49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1315\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mX\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1316\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1318\u001b[0m y \u001b[38;5;241m=\u001b[39m _check_y(y, multi_output\u001b[38;5;241m=\u001b[39mmulti_output, y_numeric\u001b[38;5;241m=\u001b[39my_numeric, estimator\u001b[38;5;241m=\u001b[39mestimator)\n\u001b[0;32m 1320\u001b[0m check_consistent_length(X, y)\n",
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1012\u001b[0m, in \u001b[0;36mcheck_array\u001b[1;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)\u001b[0m\n\u001b[0;32m 1010\u001b[0m array \u001b[38;5;241m=\u001b[39m xp\u001b[38;5;241m.\u001b[39mastype(array, dtype, copy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 1011\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1012\u001b[0m array \u001b[38;5;241m=\u001b[39m \u001b[43m_asarray_with_order\u001b[49m\u001b[43m(\u001b[49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mxp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mxp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1013\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ComplexWarning \u001b[38;5;28;01mas\u001b[39;00m complex_warning:\n\u001b[0;32m 1014\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 1015\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mComplex data not supported\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(array)\n\u001b[0;32m 1016\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mcomplex_warning\u001b[39;00m\n",
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\utils\\_array_api.py:745\u001b[0m, in \u001b[0;36m_asarray_with_order\u001b[1;34m(array, dtype, order, copy, xp, device)\u001b[0m\n\u001b[0;32m 743\u001b[0m array \u001b[38;5;241m=\u001b[39m numpy\u001b[38;5;241m.\u001b[39marray(array, order\u001b[38;5;241m=\u001b[39morder, dtype\u001b[38;5;241m=\u001b[39mdtype)\n\u001b[0;32m 744\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 745\u001b[0m array \u001b[38;5;241m=\u001b[39m \u001b[43mnumpy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 747\u001b[0m \u001b[38;5;66;03m# At this point array is a NumPy ndarray. We convert it to an array\u001b[39;00m\n\u001b[0;32m 748\u001b[0m \u001b[38;5;66;03m# container that is consistent with the input's namespace.\u001b[39;00m\n\u001b[0;32m 749\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m xp\u001b[38;5;241m.\u001b[39masarray(array)\n",
"\u001b[1;31mValueError\u001b[0m: could not convert string to float: '(2018 PR10)'"
]
}
],
"source": [
"import math\n",
"from pandas import DataFrame\n",
"from sklearn import metrics\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
"\n",
" fitted_model = models[model_name][\"model\"].fit(\n",
" X_train.values, y_train.values.ravel()\n",
" )\n",
" y_train_pred = fitted_model.predict(X_train.values)\n",
" y_test_pred = fitted_model.predict(X_test.values)\n",
" models[model_name][\"fitted\"] = fitted_model\n",
" models[model_name][\"train_preds\"] = y_train_pred\n",
" models[model_name][\"preds\"] = y_test_pred\n",
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_train, y_train_pred)\n",
" )\n",
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}