4115 lines
430 KiB
Plaintext
4115 lines
430 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Начало 4-й лабораторной\n",
|
||
"#### Ближайшие объекты к Земле"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Index(['id', 'name', 'est_diameter_min', 'est_diameter_max',\n",
|
||
" 'relative_velocity', 'miss_distance', 'orbiting_body', 'sentry_object',\n",
|
||
" 'absolute_magnitude', 'hazardous'],\n",
|
||
" dtype='object')\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>2162635</td>\n",
|
||
" <td>162635 (2000 SS164)</td>\n",
|
||
" <td>1.198271</td>\n",
|
||
" <td>2.679415</td>\n",
|
||
" <td>13569.249224</td>\n",
|
||
" <td>5.483974e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>16.73</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>2277475</td>\n",
|
||
" <td>277475 (2005 WK4)</td>\n",
|
||
" <td>0.265800</td>\n",
|
||
" <td>0.594347</td>\n",
|
||
" <td>73588.726663</td>\n",
|
||
" <td>6.143813e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.00</td>\n",
|
||
" <td>True</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>2512244</td>\n",
|
||
" <td>512244 (2015 YE18)</td>\n",
|
||
" <td>0.722030</td>\n",
|
||
" <td>1.614507</td>\n",
|
||
" <td>114258.692129</td>\n",
|
||
" <td>4.979872e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>17.83</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>3596030</td>\n",
|
||
" <td>(2012 BV13)</td>\n",
|
||
" <td>0.096506</td>\n",
|
||
" <td>0.215794</td>\n",
|
||
" <td>24764.303138</td>\n",
|
||
" <td>2.543497e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>3667127</td>\n",
|
||
" <td>(2014 GE35)</td>\n",
|
||
" <td>0.255009</td>\n",
|
||
" <td>0.570217</td>\n",
|
||
" <td>42737.733765</td>\n",
|
||
" <td>4.627557e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.09</td>\n",
|
||
" <td>True</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90831</th>\n",
|
||
" <td>3763337</td>\n",
|
||
" <td>(2016 VX1)</td>\n",
|
||
" <td>0.026580</td>\n",
|
||
" <td>0.059435</td>\n",
|
||
" <td>52078.886692</td>\n",
|
||
" <td>1.230039e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.00</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90832</th>\n",
|
||
" <td>3837603</td>\n",
|
||
" <td>(2019 AD3)</td>\n",
|
||
" <td>0.016771</td>\n",
|
||
" <td>0.037501</td>\n",
|
||
" <td>46114.605073</td>\n",
|
||
" <td>5.432121e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>26.00</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90833</th>\n",
|
||
" <td>54017201</td>\n",
|
||
" <td>(2020 JP3)</td>\n",
|
||
" <td>0.031956</td>\n",
|
||
" <td>0.071456</td>\n",
|
||
" <td>7566.807732</td>\n",
|
||
" <td>2.840077e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.60</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90834</th>\n",
|
||
" <td>54115824</td>\n",
|
||
" <td>(2021 CN5)</td>\n",
|
||
" <td>0.007321</td>\n",
|
||
" <td>0.016370</td>\n",
|
||
" <td>69199.154484</td>\n",
|
||
" <td>6.869206e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.80</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90835</th>\n",
|
||
" <td>54205447</td>\n",
|
||
" <td>(2021 TW7)</td>\n",
|
||
" <td>0.039862</td>\n",
|
||
" <td>0.089133</td>\n",
|
||
" <td>27024.455553</td>\n",
|
||
" <td>5.977213e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.12</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>90836 rows × 10 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"0 2162635 162635 (2000 SS164) 1.198271 2.679415 \n",
|
||
"1 2277475 277475 (2005 WK4) 0.265800 0.594347 \n",
|
||
"2 2512244 512244 (2015 YE18) 0.722030 1.614507 \n",
|
||
"3 3596030 (2012 BV13) 0.096506 0.215794 \n",
|
||
"4 3667127 (2014 GE35) 0.255009 0.570217 \n",
|
||
"... ... ... ... ... \n",
|
||
"90831 3763337 (2016 VX1) 0.026580 0.059435 \n",
|
||
"90832 3837603 (2019 AD3) 0.016771 0.037501 \n",
|
||
"90833 54017201 (2020 JP3) 0.031956 0.071456 \n",
|
||
"90834 54115824 (2021 CN5) 0.007321 0.016370 \n",
|
||
"90835 54205447 (2021 TW7) 0.039862 0.089133 \n",
|
||
"\n",
|
||
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
|
||
"0 13569.249224 5.483974e+07 Earth False \n",
|
||
"1 73588.726663 6.143813e+07 Earth False \n",
|
||
"2 114258.692129 4.979872e+07 Earth False \n",
|
||
"3 24764.303138 2.543497e+07 Earth False \n",
|
||
"4 42737.733765 4.627557e+07 Earth False \n",
|
||
"... ... ... ... ... \n",
|
||
"90831 52078.886692 1.230039e+07 Earth False \n",
|
||
"90832 46114.605073 5.432121e+07 Earth False \n",
|
||
"90833 7566.807732 2.840077e+07 Earth False \n",
|
||
"90834 69199.154484 6.869206e+07 Earth False \n",
|
||
"90835 27024.455553 5.977213e+07 Earth False \n",
|
||
"\n",
|
||
" absolute_magnitude hazardous \n",
|
||
"0 16.73 False \n",
|
||
"1 20.00 True \n",
|
||
"2 17.83 False \n",
|
||
"3 22.20 False \n",
|
||
"4 20.09 True \n",
|
||
"... ... ... \n",
|
||
"90831 25.00 False \n",
|
||
"90832 26.00 False \n",
|
||
"90833 24.60 False \n",
|
||
"90834 27.80 False \n",
|
||
"90835 24.12 False \n",
|
||
"\n",
|
||
"[90836 rows x 10 columns]"
|
||
]
|
||
},
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Plotting and data-handling libraries\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"\n",
"# scikit-learn: global configuration and dataset splitting\n",
"from sklearn import set_config\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Have sklearn transformers return pandas DataFrames instead of NumPy arrays\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Load the near-Earth objects dataset and list its columns\n",
"df = pd.read_csv(\".//static//csv//neo.csv\")\n",
"print(df.columns)\n",
"df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Бизнес-цели:\n",
|
||
"\n",
|
||
"1. Идентификация потенциально опасных объектов\n",
|
||
"\n",
|
||
"Описание: классифицировать астероиды как потенциально опасные или безопасные (используя целевой признак \"hazardous\"). Эта задача актуальна для оценки рисков и подготовки соответствующих действий по защите Земли.\n",
|
||
"\n",
|
||
"2. Прогнозирование минимального расстояния до Земли\n",
|
||
"\n",
|
||
"Описание: предсказать минимальное расстояние до Земли (признак \"miss_distance\") для новых объектов на основе характеристик астероида (скорости, размера и других параметров). Это позволит планировать исследования и наблюдения с учётом степени опасности объекта."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Определение достижимого уровня качества модели для первой задачи "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Разделение набора данных на обучающую и тестовую выборки (80/20) для задачи классификации\n",
|
||
"\n",
|
||
"Целевой признак -- hazardous"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2639</th>\n",
|
||
" <td>3634614</td>\n",
|
||
" <td>(2013 GT66)</td>\n",
|
||
" <td>0.024241</td>\n",
|
||
" <td>0.054205</td>\n",
|
||
" <td>43303.999094</td>\n",
|
||
" <td>4.814117e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29138</th>\n",
|
||
" <td>54143560</td>\n",
|
||
" <td>(2021 JU1)</td>\n",
|
||
" <td>0.030238</td>\n",
|
||
" <td>0.067615</td>\n",
|
||
" <td>21770.790211</td>\n",
|
||
" <td>5.646643e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.72</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>36927</th>\n",
|
||
" <td>3836085</td>\n",
|
||
" <td>(2018 VQ3)</td>\n",
|
||
" <td>0.201630</td>\n",
|
||
" <td>0.450858</td>\n",
|
||
" <td>109358.123029</td>\n",
|
||
" <td>6.435051e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.60</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61855</th>\n",
|
||
" <td>3769804</td>\n",
|
||
" <td>(2017 DJ34)</td>\n",
|
||
" <td>0.160160</td>\n",
|
||
" <td>0.358129</td>\n",
|
||
" <td>78494.609756</td>\n",
|
||
" <td>5.595780e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>21.10</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>15916</th>\n",
|
||
" <td>3824978</td>\n",
|
||
" <td>(2018 KS)</td>\n",
|
||
" <td>0.006991</td>\n",
|
||
" <td>0.015633</td>\n",
|
||
" <td>19077.749486</td>\n",
|
||
" <td>3.834648e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.90</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29491</th>\n",
|
||
" <td>3827304</td>\n",
|
||
" <td>(2018 RR1)</td>\n",
|
||
" <td>0.002658</td>\n",
|
||
" <td>0.005943</td>\n",
|
||
" <td>19826.895880</td>\n",
|
||
" <td>3.852881e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>30.00</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>18373</th>\n",
|
||
" <td>3735468</td>\n",
|
||
" <td>(2015 WY1)</td>\n",
|
||
" <td>0.103408</td>\n",
|
||
" <td>0.231228</td>\n",
|
||
" <td>82856.544926</td>\n",
|
||
" <td>7.314334e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.05</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>25031</th>\n",
|
||
" <td>3802041</td>\n",
|
||
" <td>(2018 FE3)</td>\n",
|
||
" <td>0.009651</td>\n",
|
||
" <td>0.021579</td>\n",
|
||
" <td>34243.774201</td>\n",
|
||
" <td>4.257719e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>35456</th>\n",
|
||
" <td>3430406</td>\n",
|
||
" <td>(2008 TR10)</td>\n",
|
||
" <td>0.221083</td>\n",
|
||
" <td>0.494356</td>\n",
|
||
" <td>19557.289783</td>\n",
|
||
" <td>2.152970e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.40</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>14305</th>\n",
|
||
" <td>3285300</td>\n",
|
||
" <td>(2005 OG3)</td>\n",
|
||
" <td>0.298233</td>\n",
|
||
" <td>0.666868</td>\n",
|
||
" <td>20309.404706</td>\n",
|
||
" <td>1.770015e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>19.75</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>72668 rows × 10 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"2639 3634614 (2013 GT66) 0.024241 0.054205 \n",
|
||
"29138 54143560 (2021 JU1) 0.030238 0.067615 \n",
|
||
"36927 3836085 (2018 VQ3) 0.201630 0.450858 \n",
|
||
"61855 3769804 (2017 DJ34) 0.160160 0.358129 \n",
|
||
"15916 3824978 (2018 KS) 0.006991 0.015633 \n",
|
||
"... ... ... ... ... \n",
|
||
"29491 3827304 (2018 RR1) 0.002658 0.005943 \n",
|
||
"18373 3735468 (2015 WY1) 0.103408 0.231228 \n",
|
||
"25031 3802041 (2018 FE3) 0.009651 0.021579 \n",
|
||
"35456 3430406 (2008 TR10) 0.221083 0.494356 \n",
|
||
"14305 3285300 (2005 OG3) 0.298233 0.666868 \n",
|
||
"\n",
|
||
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
|
||
"2639 43303.999094 4.814117e+07 Earth False \n",
|
||
"29138 21770.790211 5.646643e+07 Earth False \n",
|
||
"36927 109358.123029 6.435051e+07 Earth False \n",
|
||
"61855 78494.609756 5.595780e+07 Earth False \n",
|
||
"15916 19077.749486 3.834648e+07 Earth False \n",
|
||
"... ... ... ... ... \n",
|
||
"29491 19826.895880 3.852881e+07 Earth False \n",
|
||
"18373 82856.544926 7.314334e+07 Earth False \n",
|
||
"25031 34243.774201 4.257719e+07 Earth False \n",
|
||
"35456 19557.289783 2.152970e+07 Earth False \n",
|
||
"14305 20309.404706 1.770015e+07 Earth False \n",
|
||
"\n",
|
||
" absolute_magnitude hazardous \n",
|
||
"2639 25.20 False \n",
|
||
"29138 24.72 False \n",
|
||
"36927 20.60 False \n",
|
||
"61855 21.10 False \n",
|
||
"15916 27.90 False \n",
|
||
"... ... ... \n",
|
||
"29491 30.00 False \n",
|
||
"18373 22.05 False \n",
|
||
"25031 27.20 False \n",
|
||
"35456 20.40 False \n",
|
||
"14305 19.75 False \n",
|
||
"\n",
|
||
"[72668 rows x 10 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2639</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29138</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>36927</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61855</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>15916</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29491</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>18373</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>25031</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>35456</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>14305</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>72668 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" hazardous\n",
|
||
"2639 False\n",
|
||
"29138 False\n",
|
||
"36927 False\n",
|
||
"61855 False\n",
|
||
"15916 False\n",
|
||
"... ...\n",
|
||
"29491 False\n",
|
||
"18373 False\n",
|
||
"25031 False\n",
|
||
"35456 False\n",
|
||
"14305 False\n",
|
||
"\n",
|
||
"[72668 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>9040</th>\n",
|
||
" <td>2474532</td>\n",
|
||
" <td>474532 (2003 VG1)</td>\n",
|
||
" <td>0.472667</td>\n",
|
||
" <td>1.056915</td>\n",
|
||
" <td>21779.237137</td>\n",
|
||
" <td>3.443050e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>18.75</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>67305</th>\n",
|
||
" <td>3774018</td>\n",
|
||
" <td>(2017 HF1)</td>\n",
|
||
" <td>0.084053</td>\n",
|
||
" <td>0.187949</td>\n",
|
||
" <td>53291.016226</td>\n",
|
||
" <td>6.862591e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.50</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>77741</th>\n",
|
||
" <td>54269585</td>\n",
|
||
" <td>(2022 GQ2)</td>\n",
|
||
" <td>0.018220</td>\n",
|
||
" <td>0.040742</td>\n",
|
||
" <td>43089.046433</td>\n",
|
||
" <td>2.592726e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.82</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>81520</th>\n",
|
||
" <td>54097970</td>\n",
|
||
" <td>(2020 XS)</td>\n",
|
||
" <td>0.152952</td>\n",
|
||
" <td>0.342011</td>\n",
|
||
" <td>93246.455599</td>\n",
|
||
" <td>4.709054e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>21.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>508</th>\n",
|
||
" <td>3730802</td>\n",
|
||
" <td>(2015 TT238)</td>\n",
|
||
" <td>0.031956</td>\n",
|
||
" <td>0.071456</td>\n",
|
||
" <td>37708.258544</td>\n",
|
||
" <td>4.232149e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.60</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>28261</th>\n",
|
||
" <td>3532365</td>\n",
|
||
" <td>(2010 MH1)</td>\n",
|
||
" <td>0.139494</td>\n",
|
||
" <td>0.311918</td>\n",
|
||
" <td>37604.980238</td>\n",
|
||
" <td>7.369507e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>21.40</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1159</th>\n",
|
||
" <td>54073345</td>\n",
|
||
" <td>(2020 UE)</td>\n",
|
||
" <td>0.020728</td>\n",
|
||
" <td>0.046349</td>\n",
|
||
" <td>36720.077728</td>\n",
|
||
" <td>3.366114e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.54</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>48095</th>\n",
|
||
" <td>3836195</td>\n",
|
||
" <td>(2018 VT7)</td>\n",
|
||
" <td>0.006991</td>\n",
|
||
" <td>0.015633</td>\n",
|
||
" <td>7616.496535</td>\n",
|
||
" <td>6.376350e+06</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.90</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90234</th>\n",
|
||
" <td>3752902</td>\n",
|
||
" <td>(2016 JG12)</td>\n",
|
||
" <td>0.084053</td>\n",
|
||
" <td>0.187949</td>\n",
|
||
" <td>21894.554692</td>\n",
|
||
" <td>5.736984e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.50</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>12013</th>\n",
|
||
" <td>3445077</td>\n",
|
||
" <td>(2009 BM58)</td>\n",
|
||
" <td>0.038420</td>\n",
|
||
" <td>0.085909</td>\n",
|
||
" <td>49828.611609</td>\n",
|
||
" <td>4.305599e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>18168 rows × 10 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"9040 2474532 474532 (2003 VG1) 0.472667 1.056915 \n",
|
||
"67305 3774018 (2017 HF1) 0.084053 0.187949 \n",
|
||
"77741 54269585 (2022 GQ2) 0.018220 0.040742 \n",
|
||
"81520 54097970 (2020 XS) 0.152952 0.342011 \n",
|
||
"508 3730802 (2015 TT238) 0.031956 0.071456 \n",
|
||
"... ... ... ... ... \n",
|
||
"28261 3532365 (2010 MH1) 0.139494 0.311918 \n",
|
||
"1159 54073345 (2020 UE) 0.020728 0.046349 \n",
|
||
"48095 3836195 (2018 VT7) 0.006991 0.015633 \n",
|
||
"90234 3752902 (2016 JG12) 0.084053 0.187949 \n",
|
||
"12013 3445077 (2009 BM58) 0.038420 0.085909 \n",
|
||
"\n",
|
||
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
|
||
"9040 21779.237137 3.443050e+07 Earth False \n",
|
||
"67305 53291.016226 6.862591e+07 Earth False \n",
|
||
"77741 43089.046433 2.592726e+07 Earth False \n",
|
||
"81520 93246.455599 4.709054e+07 Earth False \n",
|
||
"508 37708.258544 4.232149e+07 Earth False \n",
|
||
"... ... ... ... ... \n",
|
||
"28261 37604.980238 7.369507e+07 Earth False \n",
|
||
"1159 36720.077728 3.366114e+07 Earth False \n",
|
||
"48095 7616.496535 6.376350e+06 Earth False \n",
|
||
"90234 21894.554692 5.736984e+07 Earth False \n",
|
||
"12013 49828.611609 4.305599e+07 Earth False \n",
|
||
"\n",
|
||
" absolute_magnitude hazardous \n",
|
||
"9040 18.75 False \n",
|
||
"67305 22.50 False \n",
|
||
"77741 25.82 False \n",
|
||
"81520 21.20 False \n",
|
||
"508 24.60 False \n",
|
||
"... ... ... \n",
|
||
"28261 21.40 False \n",
|
||
"1159 25.54 False \n",
|
||
"48095 27.90 False \n",
|
||
"90234 22.50 False \n",
|
||
"12013 24.20 False \n",
|
||
"\n",
|
||
"[18168 rows x 10 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>9040</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>67305</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>77741</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>81520</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>508</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>28261</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1159</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>48095</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90234</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>12013</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>18168 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" hazardous\n",
|
||
"9040 False\n",
|
||
"67305 False\n",
|
||
"77741 False\n",
|
||
"81520 False\n",
|
||
"508 False\n",
|
||
"... ...\n",
|
||
"28261 False\n",
|
||
"1159 False\n",
|
||
"48095 False\n",
|
||
"90234 False\n",
|
||
"12013 False\n",
|
||
"\n",
|
||
"[18168 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"import math\n",
"from typing import Tuple\n",
"\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Fixed random seed so the split is reproducible\n",
"random_state = 42\n",
"\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input,\n",
"    stratify_colname=\"y\",\n",
"    frac_train=0.6,\n",
"    frac_val=0.15,\n",
"    frac_test=0.25,\n",
"    random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"Split ``df_input`` into stratified train / validation / test parts.\n",
"\n",
"    The X parts keep every column of ``df_input`` (including ``stratify_colname``);\n",
"    the y parts are one-column DataFrames holding ``stratify_colname`` only.\n",
"\n",
"    Args:\n",
"        df_input: dataset to split.\n",
"        stratify_colname: column whose class proportions are preserved in every part.\n",
"        frac_train, frac_val, frac_test: row fractions of the three parts; they must\n",
"            sum to 1.0 (up to floating-point rounding). If ``frac_val <= 0`` no\n",
"            validation part is produced and empty DataFrames are returned in its place.\n",
"        random_state: seed forwarded to ``train_test_split``.\n",
"\n",
"    Returns:\n",
"        (df_train, df_val, df_test, y_train, y_val, y_test)\n",
"\n",
"    Raises:\n",
"        ValueError: if the fractions do not sum to 1.0 or ``stratify_colname``\n",
"            is not a column of ``df_input``.\n",
"    \"\"\"\n",
"    # Compare with a tolerance: in binary floating point 0.7 + 0.2 + 0.1 evaluates to\n",
"    # 0.9999999999999999, so an exact `!= 1.0` test would reject valid fractions.\n",
"    if not math.isclose(frac_train + frac_val + frac_test, 1.0):\n",
"        raise ValueError(\n",
"            \"fractions %f, %f, %f do not add up to 1.0\"\n",
"            % (frac_train, frac_val, frac_test)\n",
"        )\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"    X = df_input  # Contains all columns.\n",
"    # One-column DataFrame with the labels to stratify on.\n",
"    y = df_input[[stratify_colname]]\n",
"    # First split: train vs. everything else (temp).\n",
"    df_train, df_temp, y_train, y_temp = train_test_split(\n",
"        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
"    )\n",
"    if frac_val <= 0:\n",
"        # No validation part requested: the temp part is the test part.\n",
"        assert len(df_input) == len(df_train) + len(df_temp)\n",
"        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
"    # Second split: temp -> validation and test, keeping their requested ratio.\n",
"    relative_frac_test = frac_test / (frac_val + frac_test)\n",
"    df_val, df_test, y_val, y_test = train_test_split(\n",
"        df_temp,\n",
"        y_temp,\n",
"        stratify=y_temp,\n",
"        test_size=relative_frac_test,\n",
"        random_state=random_state,\n",
"    )\n",
"    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"    return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"\n",
"# 80/20 stratified split on the \"hazardous\" target (no validation part)\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
"    df, stratify_colname=\"hazardous\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Формирование конвейера для классификации данных\n",
|
||
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
|
||
"\n",
|
||
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных значений и унитарное кодирование\n",
|
||
"\n",
|
||
"features_preprocessing -- трансформер для предобработки признаков\n",
|
||
"\n",
|
||
"features_engineering -- трансформер для конструирования признаков\n",
|
||
"\n",
|
||
"drop_columns -- трансформер для удаления колонок\n",
|
||
"\n",
|
||
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"\n",
"\n",
"class StarbucksFeatures(BaseEstimator, TransformerMixin):\n",
"    \"\"\"Feature-engineering transformer (the class name is left over from an earlier lab).\n",
"\n",
"    ``transform`` returns a copy of ``X`` in which the boolean flag columns\n",
"    ``hazardous`` / ``sentry_object`` (when present) are cast to 0/1 integers and a\n",
"    ``Length_to_Width_Ratio`` column equal to ``X[numerator_col] / X[denominator_col]``\n",
"    is appended.\n",
"\n",
"    NOTE: the default ``\"x\"`` / ``\"y\"`` columns do not exist in the NEO data loaded\n",
"    above, so pass existing column names before using this step. It is not part of\n",
"    ``pipeline`` below.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, numerator_col=\"x\", denominator_col=\"y\"):\n",
"        self.numerator_col = numerator_col\n",
"        self.denominator_col = denominator_col\n",
"\n",
"    def fit(self, X, y=None):\n",
"        # Stateless: nothing to learn.\n",
"        return self\n",
"\n",
"    def transform(self, X, y=None):\n",
"        # Work on a copy so the caller's frame (e.g. X_train) is not modified in place.\n",
"        X = X.copy()\n",
"        # Cast boolean flags to 0/1 integers, skipping any that were already dropped.\n",
"        for column in (\"hazardous\", \"sentry_object\"):\n",
"            if column in X.columns:\n",
"                X[column] = X[column].astype(int)\n",
"        X[\"Length_to_Width_Ratio\"] = X[self.numerator_col] / X[self.denominator_col]\n",
"        return X\n",
"\n",
"    def get_feature_names_out(self, features_in):\n",
"        return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
"\n",
"\n",
"# Target of the classification task. It is still a column of X_train / X_test\n",
"# (the split keeps every column in X), so it must be removed from the features.\n",
"target_column = \"hazardous\"\n",
"\n",
"# Columns removed before the model:\n",
"#  - id, name: identifiers, not physical properties of the object;\n",
"#  - orbiting_body: \"Earth\" in every row displayed above;\n",
"#  - sentry_object: boolean flag that is False in every row displayed above\n",
"#    (TODO confirm it is constant on the full data: df[\"sentry_object\"].nunique());\n",
"#    SimpleImputer also rejects bool-dtype input;\n",
"#  - the target itself, to avoid target leakage.\n",
"columns_to_drop = [\"id\", \"name\", \"orbiting_body\", \"sentry_object\", target_column]\n",
"num_columns = [\n",
"    \"est_diameter_min\",\n",
"    \"est_diameter_max\",\n",
"    \"relative_velocity\",\n",
"    \"miss_distance\",\n",
"    \"absolute_magnitude\",\n",
"]\n",
"# No categorical features remain after the drops above; the categorical branch\n",
"# is kept so such columns can be listed here later.\n",
"cat_columns = []\n",
"\n",
"# Numeric preprocessing: median imputation + standardisation\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
"    [\n",
"        (\"imputer\", num_imputer),\n",
"        (\"scaler\", num_scaler),\n",
"    ]\n",
")\n",
"\n",
"# Categorical preprocessing: constant imputation + one-hot encoding\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
"    [\n",
"        (\"imputer\", cat_imputer),\n",
"        (\"encoder\", cat_encoder),\n",
"    ]\n",
")\n",
"\n",
"# Feature preprocessing; columns not listed pass through unchanged\n",
"# and are removed afterwards by drop_columns.\n",
"features_preprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"preprocessing_num\", preprocessing_num, num_columns),\n",
"        (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"# Remove identifiers, uninformative columns and the target\n",
"drop_columns = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"drop_columns\", \"drop\", columns_to_drop),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"# Post-processing: nothing to do for this dataset (it has no \"Cabin_type\"-like\n",
"# column to encode), so this is a pass-through kept for code that references it.\n",
"features_postprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"# Full pipeline: preprocessing -> column removal -> model.\n",
"# \"hazardous\" is a True/False label, so a classifier is used.\n",
"pipeline = Pipeline(\n",
"    [\n",
"        (\"features_preprocessing\", features_preprocessing),\n",
"        (\"drop_columns\", drop_columns),\n",
"        (\"model\", RandomForestClassifier(random_state=42)),  # same seed as the split\n",
"    ]\n",
")\n",
"\n",
"\n",
"def train_pipeline(X, y):\n",
"    \"\"\"Fit ``pipeline`` on features ``X`` and labels ``y`` and return it.\n",
"\n",
"    ``X`` must contain every column named in ``num_columns`` and ``columns_to_drop``\n",
"    (as X_train does). ``y`` may be a Series or a one-column DataFrame such as\n",
"    ``y_train``; it is flattened to 1-D, the label shape the classifier expects.\n",
"    \"\"\"\n",
"    return pipeline.fit(X, np.ravel(y))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Демонстрация работы конвейера"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" <th>id</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2639</th>\n",
|
||
" <td>-0.331616</td>\n",
|
||
" <td>-0.331616</td>\n",
|
||
" <td>-0.188160</td>\n",
|
||
" <td>0.494297</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.577785</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3634614</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29138</th>\n",
|
||
" <td>-0.312486</td>\n",
|
||
" <td>-0.312486</td>\n",
|
||
" <td>-1.040729</td>\n",
|
||
" <td>0.866716</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.412170</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>54143560</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>36927</th>\n",
|
||
" <td>0.234246</td>\n",
|
||
" <td>0.234246</td>\n",
|
||
" <td>2.427134</td>\n",
|
||
" <td>1.219399</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-1.009355</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3836085</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61855</th>\n",
|
||
" <td>0.101960</td>\n",
|
||
" <td>0.101960</td>\n",
|
||
" <td>1.205148</td>\n",
|
||
" <td>0.843963</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-0.836840</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3769804</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>15916</th>\n",
|
||
" <td>-0.386643</td>\n",
|
||
" <td>-0.386643</td>\n",
|
||
" <td>-1.147355</td>\n",
|
||
" <td>0.056145</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.509367</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3824978</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29491</th>\n",
|
||
" <td>-0.400466</td>\n",
|
||
" <td>-0.400466</td>\n",
|
||
" <td>-1.117694</td>\n",
|
||
" <td>0.064301</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>2.233931</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3827304</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>18373</th>\n",
|
||
" <td>-0.079077</td>\n",
|
||
" <td>-0.079077</td>\n",
|
||
" <td>1.377851</td>\n",
|
||
" <td>1.612734</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-0.509061</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3735468</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>25031</th>\n",
|
||
" <td>-0.378159</td>\n",
|
||
" <td>-0.378159</td>\n",
|
||
" <td>-0.546884</td>\n",
|
||
" <td>0.245400</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.267846</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3802041</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>35456</th>\n",
|
||
" <td>0.296300</td>\n",
|
||
" <td>0.296300</td>\n",
|
||
" <td>-1.128369</td>\n",
|
||
" <td>-0.696130</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-1.078361</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3430406</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>14305</th>\n",
|
||
" <td>0.542404</td>\n",
|
||
" <td>0.542404</td>\n",
|
||
" <td>-1.098590</td>\n",
|
||
" <td>-0.867440</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-1.302631</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3285300</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>72668 rows × 8 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" est_diameter_min est_diameter_max relative_velocity miss_distance \\\n",
|
||
"2639 -0.331616 -0.331616 -0.188160 0.494297 \n",
|
||
"29138 -0.312486 -0.312486 -1.040729 0.866716 \n",
|
||
"36927 0.234246 0.234246 2.427134 1.219399 \n",
|
||
"61855 0.101960 0.101960 1.205148 0.843963 \n",
|
||
"15916 -0.386643 -0.386643 -1.147355 0.056145 \n",
|
||
"... ... ... ... ... \n",
|
||
"29491 -0.400466 -0.400466 -1.117694 0.064301 \n",
|
||
"18373 -0.079077 -0.079077 1.377851 1.612734 \n",
|
||
"25031 -0.378159 -0.378159 -0.546884 0.245400 \n",
|
||
"35456 0.296300 0.296300 -1.128369 -0.696130 \n",
|
||
"14305 0.542404 0.542404 -1.098590 -0.867440 \n",
|
||
"\n",
|
||
" sentry_object absolute_magnitude hazardous id \n",
|
||
"2639 0.0 0.577785 -0.328347 3634614 \n",
|
||
"29138 0.0 0.412170 -0.328347 54143560 \n",
|
||
"36927 0.0 -1.009355 -0.328347 3836085 \n",
|
||
"61855 0.0 -0.836840 -0.328347 3769804 \n",
|
||
"15916 0.0 1.509367 -0.328347 3824978 \n",
|
||
"... ... ... ... ... \n",
|
||
"29491 0.0 2.233931 -0.328347 3827304 \n",
|
||
"18373 0.0 -0.509061 -0.328347 3735468 \n",
|
||
"25031 0.0 1.267846 -0.328347 3802041 \n",
|
||
"35456 0.0 -1.078361 -0.328347 3430406 \n",
|
||
"14305 0.0 -1.302631 -0.328347 3285300 \n",
|
||
"\n",
|
||
"[72668 rows x 8 columns]"
|
||
]
|
||
},
|
||
"execution_count": 22,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Fit the preprocessing pipeline on the training split and show the result as a\n",
"# DataFrame labelled with the pipeline's output feature names.\n",
"# NOTE(review): `pipeline_end` appears to be created outside the pipeline cell\n",
"# above (that cell builds `pipeline`); confirm its defining cell runs first.\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
"    preprocessing_result, columns=pipeline_end.get_feature_names_out()\n",
")\n",
"preprocessed_df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Формирование набора моделей для классификации\n",
|
||
" logistic -- логистическая регрессия\n",
|
||
"\n",
|
||
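"ridge -- логистическая регрессия с L2-регуляризацией и балансировкой весов классов (вместо RidgeClassifier, у которого нет метода predict_proba, используемого при оценке)\n",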
|
||
"\n",
|
||
"decision_tree -- дерево решений\n",
|
||
"\n",
|
||
"knn -- k-ближайших соседей\n",
|
||
"\n",
|
||
"naive_bayes -- наивный Байесовский классификатор\n",
|
||
"\n",
|
||
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
|
||
"\n",
|
||
"random_forest -- метод случайного леса (набор деревьев решений)\n",
|
||
"\n",
|
||
"mlp -- многослойный персептрон (нейронная сеть)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 23,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"# Candidate classifiers keyed by short name. The training cell later stores the\n",
"# fitted pipeline, predictions and metrics in the same per-model dicts.\n",
"# `random_state` is defined elsewhere in the notebook.\n",
"class_models = {\n",
"    \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
"    # Despite the key, this is L2-regularised logistic regression with balanced class\n",
"    # weights: RidgeClassifier has no predict_proba, which the evaluation loop calls.\n",
"    \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
"    \"decision_tree\": {\n",
"        \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
"    },\n",
"    \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
"    \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
"    \"gradient_boosting\": {\n",
"        # Seeded like the other tree-based models so reruns are reproducible\n",
"        \"model\": ensemble.GradientBoostingClassifier(n_estimators=210, random_state=random_state)\n",
"    },\n",
"    \"random_forest\": {\n",
"        \"model\": ensemble.RandomForestClassifier(\n",
"            max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
"        )\n",
"    },\n",
"    \"mlp\": {\n",
"        \"model\": neural_network.MLPClassifier(\n",
"            hidden_layer_sizes=(7,),\n",
"            max_iter=500,\n",
"            early_stopping=True,\n",
"            random_state=random_state,\n",
"        )\n",
"    },\n",
"}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: logistic\n",
|
||
"Model: ridge\n",
|
||
"Model: decision_tree\n",
|
||
"Model: knn\n",
|
||
"Model: naive_bayes\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: gradient_boosting\n",
|
||
"Model: random_forest\n",
|
||
"Model: mlp\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"# Fit each candidate on top of the shared preprocessing pipeline (`pipeline_end`)\n",
"# and store the fitted pipeline, test predictions and train/test metrics in its\n",
"# class_models entry.\n",
"for model_name, model_info in class_models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    model = model_info[\"model\"]\n",
"\n",
"    model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
"    model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
"    y_train_predict = model_pipeline.predict(X_train)\n",
"    # Probability of the positive class; hard labels use a 0.5 threshold\n",
"    y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
"    y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
"    model_info[\"pipeline\"] = model_pipeline\n",
"    model_info[\"probs\"] = y_test_probs\n",
"    model_info[\"preds\"] = y_test_predict\n",
"\n",
"    # zero_division=0: a model that never predicts the positive class (here\n",
"    # naive_bayes and mlp) gets precision 0.0 explicitly instead of the same\n",
"    # value accompanied by an UndefinedMetricWarning.\n",
"    model_info[\"Precision_train\"] = metrics.precision_score(\n",
"        y_train, y_train_predict, zero_division=0\n",
"    )\n",
"    model_info[\"Precision_test\"] = metrics.precision_score(\n",
"        y_test, y_test_predict, zero_division=0\n",
"    )\n",
"    model_info[\"Recall_train\"] = metrics.recall_score(y_train, y_train_predict)\n",
"    model_info[\"Recall_test\"] = metrics.recall_score(y_test, y_test_predict)\n",
"    model_info[\"Accuracy_train\"] = metrics.accuracy_score(y_train, y_train_predict)\n",
"    model_info[\"Accuracy_test\"] = metrics.accuracy_score(y_test, y_test_predict)\n",
"    model_info[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, y_test_probs)\n",
"    model_info[\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
"    model_info[\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
"    model_info[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, y_test_predict)\n",
"    model_info[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, y_test_predict)\n",
"    model_info[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, y_test_predict)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Сводная таблица оценок качества для использованных моделей классификации"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1200x1000 with 16 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# One confusion matrix per model on a 2-column grid. Rows are true labels and\n",
"# columns are predicted labels, both in sorted label order:\n",
"# 0 = not hazardous (\"safe\"), 1 = \"hazardous\".\n",
"n_cols = 2\n",
"# Round up so an odd number of models still gets a panel each\n",
"n_rows = max(1, (len(class_models) + n_cols - 1) // n_cols)\n",
"_, ax = plt.subplots(\n",
"    n_rows, n_cols, figsize=(12, 10), sharex=False, sharey=False, squeeze=False\n",
")\n",
"for index, key in enumerate(class_models.keys()):\n",
"    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"safe\", \"hazardous\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"    disp.ax_.set_title(key)\n",
"\n",
"# Hide grid cells left empty when the number of models is odd\n",
"for unused_ax in ax.flat[len(class_models):]:\n",
"    unused_ax.set_visible(False)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"16400 - количество истинно отрицательных результатов (True Negatives): объекты, которые на самом деле не опасны (hazardous = False, класс 0, левая верхняя ячейка матрицы), и модель правильно отнесла их к безопасным.\n",
|
||
"\n",
|
||
"1768 в моделях naive_bayes и mlp - количество ложноотрицательных результатов (False Negatives): объекты, которые на самом деле принадлежат к классу \"hazardous\", но были отнесены к классу \"safe\". Эти модели вообще не предсказывают класс \"hazardous\"; у моделей с идеальными метриками те же 1768 объектов попадают в True Positives. \n",
|
||
"\n",
|
||
"Высокая доля правильных ответов (accuracy около 0.90) у naive_bayes и mlp объясняется только дисбалансом классов: около 90% объектов тестовой выборки безопасны, а все 1768 опасных объектов эти модели пропускают. Поэтому при оценке моделей нужно смотреть не только на accuracy, но и на полноту (recall) и F-меру по классу \"hazardous\".\n",
|
||
"\n",
|
||
"Точность, полнота, верность (аккуратность), F-мера"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_a22cf_row0_col0, #T_a22cf_row0_col1, #T_a22cf_row0_col2, #T_a22cf_row0_col3, #T_a22cf_row1_col0, #T_a22cf_row1_col1, #T_a22cf_row1_col2, #T_a22cf_row1_col3, #T_a22cf_row2_col0, #T_a22cf_row2_col1, #T_a22cf_row2_col2, #T_a22cf_row2_col3, #T_a22cf_row3_col0, #T_a22cf_row3_col1, #T_a22cf_row3_col2, #T_a22cf_row3_col3, #T_a22cf_row7_col2, #T_a22cf_row7_col3 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_a22cf_row0_col4, #T_a22cf_row0_col5, #T_a22cf_row0_col6, #T_a22cf_row0_col7, #T_a22cf_row1_col4, #T_a22cf_row1_col5, #T_a22cf_row1_col6, #T_a22cf_row1_col7, #T_a22cf_row2_col4, #T_a22cf_row2_col5, #T_a22cf_row2_col6, #T_a22cf_row2_col7, #T_a22cf_row3_col4, #T_a22cf_row3_col5, #T_a22cf_row3_col6, #T_a22cf_row3_col7 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row4_col0 {\n",
|
||
" background-color: #86d549;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_a22cf_row4_col1 {\n",
|
||
" background-color: #77d153;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_a22cf_row4_col2 {\n",
|
||
" background-color: #63cb5f;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_a22cf_row4_col3 {\n",
|
||
" background-color: #4ac16d;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_a22cf_row4_col4 {\n",
|
||
" background-color: #c03a83;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row4_col5 {\n",
|
||
" background-color: #b32c8e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row4_col6 {\n",
|
||
" background-color: #c7427c;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row4_col7 {\n",
|
||
" background-color: #bd3786;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row5_col0, #T_a22cf_row5_col1, #T_a22cf_row5_col2, #T_a22cf_row5_col3, #T_a22cf_row6_col0, #T_a22cf_row6_col1, #T_a22cf_row6_col2, #T_a22cf_row6_col3 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row5_col4, #T_a22cf_row6_col4 {\n",
|
||
" background-color: #8004a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row5_col5, #T_a22cf_row6_col5 {\n",
|
||
" background-color: #7d03a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row5_col6, #T_a22cf_row5_col7, #T_a22cf_row6_col6, #T_a22cf_row6_col7, #T_a22cf_row7_col4, #T_a22cf_row7_col5 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row7_col0 {\n",
|
||
" background-color: #25ac82;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row7_col1 {\n",
|
||
" background-color: #26ad81;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row7_col6 {\n",
|
||
" background-color: #ac2694;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_a22cf_row7_col7 {\n",
|
||
" background-color: #ad2793;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_a22cf\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_a22cf_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
||
" <th id=\"T_a22cf_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
||
" <th id=\"T_a22cf_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
||
" <th id=\"T_a22cf_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
||
" <th id=\"T_a22cf_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
||
" <th id=\"T_a22cf_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_a22cf_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
||
" <th id=\"T_a22cf_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_a22cf_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
|
||
" <td id=\"T_a22cf_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_a22cf_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
|
||
" <td id=\"T_a22cf_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_a22cf_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
||
" <td id=\"T_a22cf_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_a22cf_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
|
||
" <td id=\"T_a22cf_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_a22cf_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
|
||
" <td id=\"T_a22cf_row4_col0\" class=\"data row4 col0\" >0.884596</td>\n",
|
||
" <td id=\"T_a22cf_row4_col1\" class=\"data row4 col1\" >0.826374</td>\n",
|
||
" <td id=\"T_a22cf_row4_col2\" class=\"data row4 col2\" >0.744627</td>\n",
|
||
" <td id=\"T_a22cf_row4_col3\" class=\"data row4 col3\" >0.638009</td>\n",
|
||
" <td id=\"T_a22cf_row4_col4\" class=\"data row4 col4\" >0.965693</td>\n",
|
||
" <td id=\"T_a22cf_row4_col5\" class=\"data row4 col5\" >0.951728</td>\n",
|
||
" <td id=\"T_a22cf_row4_col6\" class=\"data row4 col6\" >0.808599</td>\n",
|
||
" <td id=\"T_a22cf_row4_col7\" class=\"data row4 col7\" >0.720077</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_a22cf_level0_row5\" class=\"row_heading level0 row5\" >naive_bayes</th>\n",
|
||
" <td id=\"T_a22cf_row5_col0\" class=\"data row5 col0\" >0.000000</td>\n",
|
||
" <td id=\"T_a22cf_row5_col1\" class=\"data row5 col1\" >0.000000</td>\n",
|
||
" <td id=\"T_a22cf_row5_col2\" class=\"data row5 col2\" >0.000000</td>\n",
|
||
" <td id=\"T_a22cf_row5_col3\" class=\"data row5 col3\" >0.000000</td>\n",
|
||
" <td id=\"T_a22cf_row5_col4\" class=\"data row5 col4\" >0.902681</td>\n",
|
||
" <td id=\"T_a22cf_row5_col5\" class=\"data row5 col5\" >0.902686</td>\n",
|
||
" <td id=\"T_a22cf_row5_col6\" class=\"data row5 col6\" >0.000000</td>\n",
|
||
" <td id=\"T_a22cf_row5_col7\" class=\"data row5 col7\" >0.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_a22cf_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
|
||
" <td id=\"T_a22cf_row6_col0\" class=\"data row6 col0\" >0.000000</td>\n",
|
||
" <td id=\"T_a22cf_row6_col1\" class=\"data row6 col1\" >0.000000</td>\n",
|
||
" <td id=\"T_a22cf_row6_col2\" class=\"data row6 col2\" >0.000000</td>\n",
|
||
" <td id=\"T_a22cf_row6_col3\" class=\"data row6 col3\" >0.000000</td>\n",
|
||
" <td id=\"T_a22cf_row6_col4\" class=\"data row6 col4\" >0.902681</td>\n",
|
||
" <td id=\"T_a22cf_row6_col5\" class=\"data row6 col5\" >0.902686</td>\n",
|
||
" <td id=\"T_a22cf_row6_col6\" class=\"data row6 col6\" >0.000000</td>\n",
|
||
" <td id=\"T_a22cf_row6_col7\" class=\"data row6 col7\" >0.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_a22cf_level0_row7\" class=\"row_heading level0 row7\" >ridge</th>\n",
|
||
" <td id=\"T_a22cf_row7_col0\" class=\"data row7 col0\" >0.415780</td>\n",
|
||
" <td id=\"T_a22cf_row7_col1\" class=\"data row7 col1\" >0.421253</td>\n",
|
||
" <td id=\"T_a22cf_row7_col2\" class=\"data row7 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row7_col3\" class=\"data row7 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_a22cf_row7_col4\" class=\"data row7 col4\" >0.863255</td>\n",
|
||
" <td id=\"T_a22cf_row7_col5\" class=\"data row7 col5\" >0.866303</td>\n",
|
||
" <td id=\"T_a22cf_row7_col6\" class=\"data row7 col6\" >0.587351</td>\n",
|
||
" <td id=\"T_a22cf_row7_col7\" class=\"data row7 col7\" >0.592791</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1b3e1e74950>"
|
||
]
|
||
},
|
||
"execution_count": 26,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Train/test metrics of every model as one table, best test accuracy first,\n",
"# colour-coded separately for accuracy/F1 and for precision/recall.\n",
"class_metrics = pd.DataFrame.from_dict(class_models, orient=\"index\")[\n",
"    [\n",
"        \"Precision_train\", \"Precision_test\",\n",
"        \"Recall_train\", \"Recall_test\",\n",
"        \"Accuracy_train\", \"Accuracy_test\",\n",
"        \"F1_train\", \"F1_test\",\n",
"    ]\n",
"]\n",
"(\n",
"    class_metrics.sort_values(by=\"Accuracy_test\", ascending=False)\n",
"    .style.background_gradient(\n",
"        cmap=\"plasma\",\n",
"        low=0.3,\n",
"        high=1,\n",
"        subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"viridis\",\n",
"        low=1,\n",
"        high=0.3,\n",
"        subset=[\"Precision_train\", \"Precision_test\", \"Recall_train\", \"Recall_test\"],\n",
"    )\n",
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
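"Логистическая регрессия, дерево решений, случайный лес и градиентный бустинг получают идеальные значения (1.0) всех метрик на обучающем и тестовом наборах данных. Такой результат почти наверняка вызван утечкой целевой переменной: столбец \"hazardous\" присутствует среди признаков после предобработки (см. вывод конвейера выше), то есть модели фактически видят ответ. KNN (F1 на тесте около 0.72) и ridge (полнота 1.0 при точности около 0.42) заметно слабее.\n",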
|
||
"\n",
|
||
"Модели Naive Bayes и MLP не предсказывают класс \"hazardous\" ни для одного объекта: точность, полнота и F-мера у них равны 0, а accuracy около 0.90 лишь отражает долю безопасных объектов в выборке (дисбаланс классов). \n",
|
||
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_46430_row0_col0, #T_46430_row0_col1, #T_46430_row1_col0, #T_46430_row1_col1, #T_46430_row2_col0, #T_46430_row2_col1, #T_46430_row3_col0, #T_46430_row3_col1 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_46430_row0_col2, #T_46430_row0_col3, #T_46430_row0_col4, #T_46430_row1_col2, #T_46430_row1_col3, #T_46430_row1_col4, #T_46430_row2_col2, #T_46430_row2_col3, #T_46430_row2_col4, #T_46430_row3_col2, #T_46430_row3_col3, #T_46430_row3_col4 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_46430_row4_col0, #T_46430_row6_col1, #T_46430_row7_col1 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_46430_row4_col1 {\n",
|
||
" background-color: #40bd72;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_46430_row4_col2 {\n",
|
||
" background-color: #d9586a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_46430_row4_col3, #T_46430_row6_col2 {\n",
|
||
" background-color: #a51f99;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_46430_row4_col4 {\n",
|
||
" background-color: #ae2892;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_46430_row5_col0 {\n",
|
||
" background-color: #4ac16d;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_46430_row5_col1 {\n",
|
||
" background-color: #5cc863;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_46430_row5_col2 {\n",
|
||
" background-color: #d14e72;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_46430_row5_col3 {\n",
|
||
" background-color: #ba3388;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_46430_row5_col4 {\n",
|
||
" background-color: #bb3488;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_46430_row6_col0, #T_46430_row7_col0 {\n",
|
||
" background-color: #1e9d89;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_46430_row6_col3, #T_46430_row6_col4, #T_46430_row7_col2, #T_46430_row7_col3, #T_46430_row7_col4 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_46430\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_46430_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_46430_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
||
" <th id=\"T_46430_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
||
" <th id=\"T_46430_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
||
" <th id=\"T_46430_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_46430_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
|
||
" <td id=\"T_46430_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_46430_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
|
||
" <td id=\"T_46430_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_46430_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
||
" <td id=\"T_46430_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_46430_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
|
||
" <td id=\"T_46430_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_46430_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_46430_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
|
||
" <td id=\"T_46430_row4_col0\" class=\"data row4 col0\" >0.866303</td>\n",
|
||
" <td id=\"T_46430_row4_col1\" class=\"data row4 col1\" >0.592791</td>\n",
|
||
" <td id=\"T_46430_row4_col2\" class=\"data row4 col2\" >0.995675</td>\n",
|
||
" <td id=\"T_46430_row4_col3\" class=\"data row4 col3\" >0.528180</td>\n",
|
||
" <td id=\"T_46430_row4_col4\" class=\"data row4 col4\" >0.599051</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_46430_level0_row5\" class=\"row_heading level0 row5\" >knn</th>\n",
|
||
" <td id=\"T_46430_row5_col0\" class=\"data row5 col0\" >0.951728</td>\n",
|
||
" <td id=\"T_46430_row5_col1\" class=\"data row5 col1\" >0.720077</td>\n",
|
||
" <td id=\"T_46430_row5_col2\" class=\"data row5 col2\" >0.953405</td>\n",
|
||
" <td id=\"T_46430_row5_col3\" class=\"data row5 col3\" >0.694141</td>\n",
|
||
" <td id=\"T_46430_row5_col4\" class=\"data row5 col4\" >0.701100</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_46430_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
|
||
" <td id=\"T_46430_row6_col0\" class=\"data row6 col0\" >0.902686</td>\n",
|
||
" <td id=\"T_46430_row6_col1\" class=\"data row6 col1\" >0.000000</td>\n",
|
||
" <td id=\"T_46430_row6_col2\" class=\"data row6 col2\" >0.766341</td>\n",
|
||
" <td id=\"T_46430_row6_col3\" class=\"data row6 col3\" >0.000000</td>\n",
|
||
" <td id=\"T_46430_row6_col4\" class=\"data row6 col4\" >0.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_46430_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
|
||
" <td id=\"T_46430_row7_col0\" class=\"data row7 col0\" >0.902686</td>\n",
|
||
" <td id=\"T_46430_row7_col1\" class=\"data row7 col1\" >0.000000</td>\n",
|
||
" <td id=\"T_46430_row7_col2\" class=\"data row7 col2\" >0.500000</td>\n",
|
||
" <td id=\"T_46430_row7_col3\" class=\"data row7 col3\" >0.000000</td>\n",
|
||
" <td id=\"T_46430_row7_col4\" class=\"data row7 col4\" >0.000000</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1b3dda0e660>"
|
||
]
|
||
},
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
||
" [\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" ]\n",
|
||
"]\n",
|
||
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" ],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Схожий вывод можно сделать и для следующих метрик: Accuracy, F1, ROC AUC, Cohen's Kappa и MCC. Все модели, кроме Naive Bayes и MLP, указывают на хорошо-развитую способность к выделению классов"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'logistic'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
|
||
"\n",
|
||
"display(best_model)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Вывод данных с ошибкой предсказания для оценки"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 29,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'Error items count: 0'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>Predicted</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
"Empty DataFrame\n",
|
||
"Columns: [id, Predicted, name, est_diameter_min, est_diameter_max, relative_velocity, miss_distance, orbiting_body, sentry_object, absolute_magnitude, hazardous]\n",
|
||
"Index: []"
|
||
]
|
||
},
|
||
"execution_count": 29,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"preprocessing_result = pipeline_end.transform(X_test)\n",
|
||
"preprocessed_df = pd.DataFrame(\n",
|
||
" preprocessing_result,\n",
|
||
" columns=pipeline_end.get_feature_names_out(),\n",
|
||
")\n",
|
||
"\n",
|
||
"y_pred = class_models[best_model][\"preds\"]\n",
|
||
"\n",
|
||
"error_index = y_test[y_test[\"hazardous\"] != y_pred].index.tolist()\n",
|
||
"display(f\"Error items count: {len(error_index)}\")\n",
|
||
"\n",
|
||
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
|
||
"error_df = X_test.loc[error_index].copy()\n",
|
||
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
|
||
"error_df.sort_index()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Пример использования обученной модели (конвейера) для предсказания\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>67305</th>\n",
|
||
" <td>3774018</td>\n",
|
||
" <td>(2017 HF1)</td>\n",
|
||
" <td>0.084053</td>\n",
|
||
" <td>0.187949</td>\n",
|
||
" <td>53291.016226</td>\n",
|
||
" <td>68625911.198806</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.5</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"67305 3774018 (2017 HF1) 0.084053 0.187949 \n",
|
||
"\n",
|
||
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
|
||
"67305 53291.016226 68625911.198806 Earth False \n",
|
||
"\n",
|
||
" absolute_magnitude hazardous \n",
|
||
"67305 22.5 False "
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" <th>id</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>67305</th>\n",
|
||
" <td>-0.140818</td>\n",
|
||
" <td>-0.140818</td>\n",
|
||
" <td>0.207258</td>\n",
|
||
" <td>1.410653</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-0.353797</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3774018.0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" est_diameter_min est_diameter_max relative_velocity miss_distance \\\n",
|
||
"67305 -0.140818 -0.140818 0.207258 1.410653 \n",
|
||
"\n",
|
||
" sentry_object absolute_magnitude hazardous id \n",
|
||
"67305 0.0 -0.353797 -0.328347 3774018.0 "
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'predicted: False (proba: [9.99855425e-01 1.44575476e-04])'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'real: 0'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"model = class_models[best_model][\"pipeline\"]\n",
|
||
"\n",
|
||
"example_id = 67305\n",
|
||
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
|
||
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
|
||
"display(test)\n",
|
||
"display(test_preprocessed)\n",
|
||
"result_proba = model.predict_proba(test)[0]\n",
|
||
"result = model.predict(test)[0]\n",
|
||
"real = int(y_test.loc[example_id].values[0])\n",
|
||
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
|
||
"display(f\"real: {real}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Подбор гиперпараметров методом поиска по сетке "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
||
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'model__criterion': 'gini',\n",
|
||
" 'model__max_depth': 5,\n",
|
||
" 'model__max_features': 'sqrt',\n",
|
||
" 'model__n_estimators': 50}"
|
||
]
|
||
},
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.model_selection import GridSearchCV\n",
|
||
"\n",
|
||
"optimized_model_type = \"random_forest\"\n",
|
||
"\n",
|
||
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
|
||
"\n",
|
||
"param_grid = {\n",
|
||
" \"model__n_estimators\": [10, 50, 100],\n",
|
||
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
|
||
" \"model__max_depth\": [5, 7, 10],\n",
|
||
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
||
"}\n",
|
||
"\n",
|
||
"gs_optomizer = GridSearchCV(\n",
|
||
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
|
||
")\n",
|
||
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
|
||
"gs_optomizer.best_params_"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Обучение модели с новыми гиперпараметрами"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 38,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"from sklearn.compose import ColumnTransformer\n",
|
||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"\n",
|
||
"# Определяем числовые признаки\n",
|
||
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
|
||
"\n",
|
||
"# Установка random_state\n",
|
||
"random_state = 42\n",
|
||
"\n",
|
||
"# Определение трансформера\n",
|
||
"pipeline_end = ColumnTransformer([\n",
|
||
" ('numeric', StandardScaler(), numeric_features),\n",
|
||
" # Добавьте другие трансформеры, если требуется\n",
|
||
"])\n",
|
||
"\n",
|
||
"# Объявление модели\n",
|
||
"optimized_model = RandomForestClassifier(\n",
|
||
" random_state=random_state,\n",
|
||
" criterion=\"gini\",\n",
|
||
" max_depth=5,\n",
|
||
" max_features=\"sqrt\",\n",
|
||
" n_estimators=50,\n",
|
||
")\n",
|
||
"\n",
|
||
"# Создание пайплайна с корректными шагами\n",
|
||
"result = {}\n",
|
||
"\n",
|
||
"# Обучение модели\n",
|
||
"result[\"pipeline\"] = Pipeline([\n",
|
||
" (\"pipeline\", pipeline_end),\n",
|
||
" (\"model\", optimized_model)\n",
|
||
"]).fit(X_train, y_train.values.ravel())\n",
|
||
"\n",
|
||
"# Прогнозирование и расчет метрик\n",
|
||
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
|
||
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
|
||
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
|
||
"\n",
|
||
"# Метрики для оценки модели\n",
|
||
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
|
||
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
|
||
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
|
||
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
|
||
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
|
||
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
|
||
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
|
||
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Формирование данных для оценки старой и новой версии модели"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 39,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
|
||
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
||
" data=class_models[optimized_model_type]\n",
|
||
")\n",
|
||
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
||
" data=result\n",
|
||
")\n",
|
||
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
|
||
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Оценка параметров старой и новой модели"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_55496_row0_col0, #T_55496_row0_col1, #T_55496_row0_col2, #T_55496_row0_col3 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_55496_row0_col4, #T_55496_row0_col5, #T_55496_row0_col6, #T_55496_row0_col7 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_55496_row1_col0, #T_55496_row1_col1, #T_55496_row1_col2, #T_55496_row1_col3 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_55496_row1_col4, #T_55496_row1_col5, #T_55496_row1_col6, #T_55496_row1_col7 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_55496\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_55496_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
||
" <th id=\"T_55496_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
||
" <th id=\"T_55496_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
||
" <th id=\"T_55496_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
||
" <th id=\"T_55496_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
||
" <th id=\"T_55496_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_55496_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
||
" <th id=\"T_55496_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th class=\"index_name level0\" >Name</th>\n",
|
||
" <th class=\"blank col0\" > </th>\n",
|
||
" <th class=\"blank col1\" > </th>\n",
|
||
" <th class=\"blank col2\" > </th>\n",
|
||
" <th class=\"blank col3\" > </th>\n",
|
||
" <th class=\"blank col4\" > </th>\n",
|
||
" <th class=\"blank col5\" > </th>\n",
|
||
" <th class=\"blank col6\" > </th>\n",
|
||
" <th class=\"blank col7\" > </th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_55496_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
||
" <td id=\"T_55496_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_55496_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_55496_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_55496_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_55496_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_55496_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_55496_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_55496_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_55496_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
||
" <td id=\"T_55496_row1_col0\" class=\"data row1 col0\" >0.833191</td>\n",
|
||
" <td id=\"T_55496_row1_col1\" class=\"data row1 col1\" >0.862500</td>\n",
|
||
" <td id=\"T_55496_row1_col2\" class=\"data row1 col2\" >0.138433</td>\n",
|
||
" <td id=\"T_55496_row1_col3\" class=\"data row1 col3\" >0.156109</td>\n",
|
||
" <td id=\"T_55496_row1_col4\" class=\"data row1 col4\" >0.913456</td>\n",
|
||
" <td id=\"T_55496_row1_col5\" class=\"data row1 col5\" >0.915456</td>\n",
|
||
" <td id=\"T_55496_row1_col6\" class=\"data row1 col6\" >0.237420</td>\n",
|
||
" <td id=\"T_55496_row1_col7\" class=\"data row1 col7\" >0.264368</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1b3e1be0920>"
|
||
]
|
||
},
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"optimized_metrics[\n",
|
||
" [\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" \"Accuracy_train\",\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_train\",\n",
|
||
" \"F1_test\",\n",
|
||
" ]\n",
|
||
"].style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_36483_row0_col0, #T_36483_row0_col1 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_36483_row0_col2, #T_36483_row0_col3, #T_36483_row0_col4 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_36483_row1_col0, #T_36483_row1_col1 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_36483_row1_col2, #T_36483_row1_col3, #T_36483_row1_col4 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_36483\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_36483_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_36483_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
||
" <th id=\"T_36483_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
||
" <th id=\"T_36483_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
||
" <th id=\"T_36483_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th class=\"index_name level0\" >Name</th>\n",
|
||
" <th class=\"blank col0\" > </th>\n",
|
||
" <th class=\"blank col1\" > </th>\n",
|
||
" <th class=\"blank col2\" > </th>\n",
|
||
" <th class=\"blank col3\" > </th>\n",
|
||
" <th class=\"blank col4\" > </th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_36483_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
||
" <td id=\"T_36483_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_36483_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_36483_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_36483_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_36483_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_36483_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
||
" <td id=\"T_36483_row1_col0\" class=\"data row1 col0\" >0.915456</td>\n",
|
||
" <td id=\"T_36483_row1_col1\" class=\"data row1 col1\" >0.264368</td>\n",
|
||
" <td id=\"T_36483_row1_col2\" class=\"data row1 col2\" >0.927493</td>\n",
|
||
" <td id=\"T_36483_row1_col3\" class=\"data row1 col3\" >0.241751</td>\n",
|
||
" <td id=\"T_36483_row1_col4\" class=\"data row1 col4\" >0.345694</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1b3e1be2de0>"
|
||
]
|
||
},
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"optimized_metrics[\n",
|
||
" [\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" ]\n",
|
||
"].style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" ],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 42,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1000x400 with 4 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
|
||
")\n",
|
||
"\n",
|
||
"for index in range(0, len(optimized_metrics)):\n",
|
||
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
|
||
" disp = ConfusionMatrixDisplay(\n",
|
||
" confusion_matrix=c_matrix, display_labels=[\"hazardous\", \"safe\"]\n",
|
||
" ).plot(ax=ax.flat[index])\n",
|
||
"\n",
|
||
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"В желтых квадрате мы наблюдаем значение 16400, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"hazardsous\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
|
||
"\n",
|
||
"В фиолетвом квадрате значение 276 указывает на количество правильно классифицированных объектов, отнесенных к классу \"More\". Это является показателем не такой высокой точности модели в определении объектов данного класса."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 201,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"(5000, 6)\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>3943344</td>\n",
|
||
" <td>0.024241</td>\n",
|
||
" <td>0.054205</td>\n",
|
||
" <td>22148.962596</td>\n",
|
||
" <td>5.028574e+07</td>\n",
|
||
" <td>25.20</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>3879239</td>\n",
|
||
" <td>0.012722</td>\n",
|
||
" <td>0.028447</td>\n",
|
||
" <td>26477.211836</td>\n",
|
||
" <td>1.683201e+06</td>\n",
|
||
" <td>26.60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>3879244</td>\n",
|
||
" <td>0.013322</td>\n",
|
||
" <td>0.029788</td>\n",
|
||
" <td>33770.201397</td>\n",
|
||
" <td>3.943220e+06</td>\n",
|
||
" <td>26.50</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>2481965</td>\n",
|
||
" <td>0.193444</td>\n",
|
||
" <td>0.432554</td>\n",
|
||
" <td>43599.575296</td>\n",
|
||
" <td>7.346837e+07</td>\n",
|
||
" <td>20.69</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>3789471</td>\n",
|
||
" <td>0.044112</td>\n",
|
||
" <td>0.098637</td>\n",
|
||
" <td>36398.080883</td>\n",
|
||
" <td>6.352916e+07</td>\n",
|
||
" <td>23.90</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4995</th>\n",
|
||
" <td>3468663</td>\n",
|
||
" <td>0.006677</td>\n",
|
||
" <td>0.014929</td>\n",
|
||
" <td>20300.398051</td>\n",
|
||
" <td>1.700006e+06</td>\n",
|
||
" <td>28.00</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4996</th>\n",
|
||
" <td>3620670</td>\n",
|
||
" <td>0.105817</td>\n",
|
||
" <td>0.236614</td>\n",
|
||
" <td>36514.062162</td>\n",
|
||
" <td>6.945396e+07</td>\n",
|
||
" <td>22.00</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4997</th>\n",
|
||
" <td>3562321</td>\n",
|
||
" <td>0.192555</td>\n",
|
||
" <td>0.430566</td>\n",
|
||
" <td>68895.907750</td>\n",
|
||
" <td>5.209557e+07</td>\n",
|
||
" <td>20.70</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4998</th>\n",
|
||
" <td>3440771</td>\n",
|
||
" <td>0.253837</td>\n",
|
||
" <td>0.567597</td>\n",
|
||
" <td>61336.513568</td>\n",
|
||
" <td>5.037204e+07</td>\n",
|
||
" <td>20.10</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4999</th>\n",
|
||
" <td>54065901</td>\n",
|
||
" <td>0.015295</td>\n",
|
||
" <td>0.034201</td>\n",
|
||
" <td>18389.028188</td>\n",
|
||
" <td>5.627145e+07</td>\n",
|
||
" <td>26.20</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>5000 rows × 6 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id est_diameter_min est_diameter_max relative_velocity \\\n",
|
||
"0 3943344 0.024241 0.054205 22148.962596 \n",
|
||
"1 3879239 0.012722 0.028447 26477.211836 \n",
|
||
"2 3879244 0.013322 0.029788 33770.201397 \n",
|
||
"3 2481965 0.193444 0.432554 43599.575296 \n",
|
||
"4 3789471 0.044112 0.098637 36398.080883 \n",
|
||
"... ... ... ... ... \n",
|
||
"4995 3468663 0.006677 0.014929 20300.398051 \n",
|
||
"4996 3620670 0.105817 0.236614 36514.062162 \n",
|
||
"4997 3562321 0.192555 0.430566 68895.907750 \n",
|
||
"4998 3440771 0.253837 0.567597 61336.513568 \n",
|
||
"4999 54065901 0.015295 0.034201 18389.028188 \n",
|
||
"\n",
|
||
" miss_distance absolute_magnitude \n",
|
||
"0 5.028574e+07 25.20 \n",
|
||
"1 1.683201e+06 26.60 \n",
|
||
"2 3.943220e+06 26.50 \n",
|
||
"3 7.346837e+07 20.69 \n",
|
||
"4 6.352916e+07 23.90 \n",
|
||
"... ... ... \n",
|
||
"4995 1.700006e+06 28.00 \n",
|
||
"4996 6.945396e+07 22.00 \n",
|
||
"4997 5.209557e+07 20.70 \n",
|
||
"4998 5.037204e+07 20.10 \n",
|
||
"4999 5.627145e+07 26.20 \n",
|
||
"\n",
|
||
"[5000 rows x 6 columns]"
|
||
]
|
||
},
|
||
"execution_count": 201,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import seaborn as sns\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"from sklearn import set_config\n",
|
||
"\n",
|
||
"random_state = 42\n",
|
||
"N_SAMPLES = 5000  # size of the working random subsample\n",
|
||
"set_config(transform_output=\"pandas\")\n",
|
||
"\n",
|
||
"df = pd.read_csv(\".//static//csv//neo.csv\")\n",
|
||
"# Drop the columns excluded from the analysis (four columns, not two)\n",
|
||
"df = df.drop(columns=[\"sentry_object\", \"hazardous\", \"orbiting_body\", \"name\"])\n",
|
||
"\n",
|
||
"# Keep a reproducible random subsample of N_SAMPLES records\n",
|
||
"df = df.sample(n=N_SAMPLES, random_state=random_state).reset_index(drop=True)\n",
|
||
"\n",
|
||
"# Check the resulting DataFrame\n",
|
||
"assert df.shape[0] == N_SAMPLES, \"unexpected number of rows after sampling\"\n",
|
||
"print(df.shape)\n",
|
||
"df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 202,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" est_diameter_min est_diameter_max relative_velocity miss_distance \\\n",
|
||
"0 1.198271 2.679415 13569.249224 5.483974e+07 \n",
|
||
"1 0.265800 0.594347 73588.726663 6.143813e+07 \n",
|
||
"2 0.722030 1.614507 114258.692129 4.979872e+07 \n",
|
||
"3 0.096506 0.215794 24764.303138 2.543497e+07 \n",
|
||
"4 0.255009 0.570217 42737.733765 4.627557e+07 \n",
|
||
"\n",
|
||
" impact_damage_index \n",
|
||
"0 0.000480 \n",
|
||
"1 0.000515 \n",
|
||
"2 0.002680 \n",
|
||
"3 0.000152 \n",
|
||
"4 0.000381 \n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"# Reuse the sampled `df` from the previous cell instead of re-reading the CSV:\n",
|
||
"# re-reading here would silently discard the random subsample and the column drops.\n",
|
||
"required_columns = [\"est_diameter_min\", \"est_diameter_max\", \"relative_velocity\", \"miss_distance\"]\n",
|
||
"missing_columns = [col for col in required_columns if col not in df.columns]\n",
|
||
"if missing_columns:\n",
|
||
"    raise ValueError(f\"Отсутствуют столбцы: {missing_columns}\")\n",
|
||
"\n",
|
||
"# miss_distance is used as a divisor below, so it must be strictly positive\n",
|
||
"assert (df[\"miss_distance\"] > 0).all(), \"miss_distance must be positive\"\n",
|
||
"\n",
|
||
"# Synthetic target \"impact_damage_index\" (heuristic formula, may be adjusted):\n",
|
||
"# a larger mean diameter and velocity raise the index, a larger miss distance lowers it.\n",
|
||
"df[\"impact_damage_index\"] = (\n",
|
||
"    (df[\"est_diameter_min\"] + df[\"est_diameter_max\"]) / 2  # mean diameter\n",
|
||
"    * df[\"relative_velocity\"]  # velocity\n",
|
||
"    / df[\"miss_distance\"]  # inverse dependence on distance\n",
|
||
")\n",
|
||
"\n",
|
||
"# Inspect the new column\n",
|
||
"df[[\"est_diameter_min\", \"est_diameter_max\", \"relative_velocity\", \"miss_distance\", \"impact_damage_index\"]].head()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Разделение набора данных на обучающую и тестовую выборки (80/20) для задачи регрессии"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 203,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>35538</th>\n",
|
||
" <td>3826685</td>\n",
|
||
" <td>0.038420</td>\n",
|
||
" <td>0.085909</td>\n",
|
||
" <td>91103.489666</td>\n",
|
||
" <td>6.350550e+07</td>\n",
|
||
" <td>24.20</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>40393</th>\n",
|
||
" <td>2277830</td>\n",
|
||
" <td>0.192555</td>\n",
|
||
" <td>0.430566</td>\n",
|
||
" <td>28359.611312</td>\n",
|
||
" <td>2.868167e+07</td>\n",
|
||
" <td>20.70</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>58540</th>\n",
|
||
" <td>3638201</td>\n",
|
||
" <td>0.004619</td>\n",
|
||
" <td>0.010329</td>\n",
|
||
" <td>107351.426865</td>\n",
|
||
" <td>5.388098e+04</td>\n",
|
||
" <td>28.80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61670</th>\n",
|
||
" <td>3836282</td>\n",
|
||
" <td>0.015295</td>\n",
|
||
" <td>0.034201</td>\n",
|
||
" <td>21423.536884</td>\n",
|
||
" <td>5.103884e+07</td>\n",
|
||
" <td>26.20</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>11435</th>\n",
|
||
" <td>3802002</td>\n",
|
||
" <td>0.011603</td>\n",
|
||
" <td>0.025944</td>\n",
|
||
" <td>69856.053840</td>\n",
|
||
" <td>7.360836e+07</td>\n",
|
||
" <td>26.80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>6265</th>\n",
|
||
" <td>2530151</td>\n",
|
||
" <td>0.211132</td>\n",
|
||
" <td>0.472106</td>\n",
|
||
" <td>88209.754856</td>\n",
|
||
" <td>4.034289e+07</td>\n",
|
||
" <td>20.50</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>54886</th>\n",
|
||
" <td>3831736</td>\n",
|
||
" <td>0.035039</td>\n",
|
||
" <td>0.078350</td>\n",
|
||
" <td>58758.452153</td>\n",
|
||
" <td>4.389994e+06</td>\n",
|
||
" <td>24.40</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>76820</th>\n",
|
||
" <td>2512234</td>\n",
|
||
" <td>0.211132</td>\n",
|
||
" <td>0.472106</td>\n",
|
||
" <td>52355.509176</td>\n",
|
||
" <td>4.380532e+07</td>\n",
|
||
" <td>20.50</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>860</th>\n",
|
||
" <td>54054466</td>\n",
|
||
" <td>0.282199</td>\n",
|
||
" <td>0.631015</td>\n",
|
||
" <td>50527.379563</td>\n",
|
||
" <td>5.837007e+07</td>\n",
|
||
" <td>19.87</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>15795</th>\n",
|
||
" <td>3773929</td>\n",
|
||
" <td>0.075258</td>\n",
|
||
" <td>0.168283</td>\n",
|
||
" <td>22527.647871</td>\n",
|
||
" <td>2.281469e+07</td>\n",
|
||
" <td>22.74</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>72668 rows × 6 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id est_diameter_min est_diameter_max relative_velocity \\\n",
|
||
"35538 3826685 0.038420 0.085909 91103.489666 \n",
|
||
"40393 2277830 0.192555 0.430566 28359.611312 \n",
|
||
"58540 3638201 0.004619 0.010329 107351.426865 \n",
|
||
"61670 3836282 0.015295 0.034201 21423.536884 \n",
|
||
"11435 3802002 0.011603 0.025944 69856.053840 \n",
|
||
"... ... ... ... ... \n",
|
||
"6265 2530151 0.211132 0.472106 88209.754856 \n",
|
||
"54886 3831736 0.035039 0.078350 58758.452153 \n",
|
||
"76820 2512234 0.211132 0.472106 52355.509176 \n",
|
||
"860 54054466 0.282199 0.631015 50527.379563 \n",
|
||
"15795 3773929 0.075258 0.168283 22527.647871 \n",
|
||
"\n",
|
||
" miss_distance absolute_magnitude \n",
|
||
"35538 6.350550e+07 24.20 \n",
|
||
"40393 2.868167e+07 20.70 \n",
|
||
"58540 5.388098e+04 28.80 \n",
|
||
"61670 5.103884e+07 26.20 \n",
|
||
"11435 7.360836e+07 26.80 \n",
|
||
"... ... ... \n",
|
||
"6265 4.034289e+07 20.50 \n",
|
||
"54886 4.389994e+06 24.40 \n",
|
||
"76820 4.380532e+07 20.50 \n",
|
||
"860 5.837007e+07 19.87 \n",
|
||
"15795 2.281469e+07 22.74 \n",
|
||
"\n",
|
||
"[72668 rows x 6 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>impact_damage_index</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>35538</th>\n",
|
||
" <td>0.000089</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>40393</th>\n",
|
||
" <td>0.000308</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>58540</th>\n",
|
||
" <td>0.014891</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61670</th>\n",
|
||
" <td>0.000010</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>11435</th>\n",
|
||
" <td>0.000018</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>6265</th>\n",
|
||
" <td>0.000747</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>54886</th>\n",
|
||
" <td>0.000759</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>76820</th>\n",
|
||
" <td>0.000408</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>860</th>\n",
|
||
" <td>0.000395</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>15795</th>\n",
|
||
" <td>0.000120</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>72668 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" impact_damage_index\n",
|
||
"35538 0.000089\n",
|
||
"40393 0.000308\n",
|
||
"58540 0.014891\n",
|
||
"61670 0.000010\n",
|
||
"11435 0.000018\n",
|
||
"... ...\n",
|
||
"6265 0.000747\n",
|
||
"54886 0.000759\n",
|
||
"76820 0.000408\n",
|
||
"860 0.000395\n",
|
||
"15795 0.000120\n",
|
||
"\n",
|
||
"[72668 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>20406</th>\n",
|
||
" <td>3943344</td>\n",
|
||
" <td>0.024241</td>\n",
|
||
" <td>0.054205</td>\n",
|
||
" <td>22148.962596</td>\n",
|
||
" <td>5.028574e+07</td>\n",
|
||
" <td>25.20</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>74443</th>\n",
|
||
" <td>3879239</td>\n",
|
||
" <td>0.012722</td>\n",
|
||
" <td>0.028447</td>\n",
|
||
" <td>26477.211836</td>\n",
|
||
" <td>1.683201e+06</td>\n",
|
||
" <td>26.60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>74306</th>\n",
|
||
" <td>3879244</td>\n",
|
||
" <td>0.013322</td>\n",
|
||
" <td>0.029788</td>\n",
|
||
" <td>33770.201397</td>\n",
|
||
" <td>3.943220e+06</td>\n",
|
||
" <td>26.50</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>45943</th>\n",
|
||
" <td>2481965</td>\n",
|
||
" <td>0.193444</td>\n",
|
||
" <td>0.432554</td>\n",
|
||
" <td>43599.575296</td>\n",
|
||
" <td>7.346837e+07</td>\n",
|
||
" <td>20.69</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>62859</th>\n",
|
||
" <td>3789471</td>\n",
|
||
" <td>0.044112</td>\n",
|
||
" <td>0.098637</td>\n",
|
||
" <td>36398.080883</td>\n",
|
||
" <td>6.352916e+07</td>\n",
|
||
" <td>23.90</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>51634</th>\n",
|
||
" <td>3694131</td>\n",
|
||
" <td>0.008801</td>\n",
|
||
" <td>0.019681</td>\n",
|
||
" <td>57414.305699</td>\n",
|
||
" <td>1.987273e+07</td>\n",
|
||
" <td>27.40</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>85083</th>\n",
|
||
" <td>54235475</td>\n",
|
||
" <td>0.024920</td>\n",
|
||
" <td>0.055724</td>\n",
|
||
" <td>50882.935767</td>\n",
|
||
" <td>3.119646e+07</td>\n",
|
||
" <td>25.14</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>38905</th>\n",
|
||
" <td>3775176</td>\n",
|
||
" <td>0.008405</td>\n",
|
||
" <td>0.018795</td>\n",
|
||
" <td>24954.754212</td>\n",
|
||
" <td>1.111942e+07</td>\n",
|
||
" <td>27.50</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>16144</th>\n",
|
||
" <td>2434734</td>\n",
|
||
" <td>0.265800</td>\n",
|
||
" <td>0.594347</td>\n",
|
||
" <td>57455.404666</td>\n",
|
||
" <td>8.501684e+06</td>\n",
|
||
" <td>20.00</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>54508</th>\n",
|
||
" <td>3170208</td>\n",
|
||
" <td>0.023150</td>\n",
|
||
" <td>0.051765</td>\n",
|
||
" <td>72602.093427</td>\n",
|
||
" <td>4.624727e+07</td>\n",
|
||
" <td>25.30</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>18168 rows × 6 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id est_diameter_min est_diameter_max relative_velocity \\\n",
|
||
"20406 3943344 0.024241 0.054205 22148.962596 \n",
|
||
"74443 3879239 0.012722 0.028447 26477.211836 \n",
|
||
"74306 3879244 0.013322 0.029788 33770.201397 \n",
|
||
"45943 2481965 0.193444 0.432554 43599.575296 \n",
|
||
"62859 3789471 0.044112 0.098637 36398.080883 \n",
|
||
"... ... ... ... ... \n",
|
||
"51634 3694131 0.008801 0.019681 57414.305699 \n",
|
||
"85083 54235475 0.024920 0.055724 50882.935767 \n",
|
||
"38905 3775176 0.008405 0.018795 24954.754212 \n",
|
||
"16144 2434734 0.265800 0.594347 57455.404666 \n",
|
||
"54508 3170208 0.023150 0.051765 72602.093427 \n",
|
||
"\n",
|
||
" miss_distance absolute_magnitude \n",
|
||
"20406 5.028574e+07 25.20 \n",
|
||
"74443 1.683201e+06 26.60 \n",
|
||
"74306 3.943220e+06 26.50 \n",
|
||
"45943 7.346837e+07 20.69 \n",
|
||
"62859 6.352916e+07 23.90 \n",
|
||
"... ... ... \n",
|
||
"51634 1.987273e+07 27.40 \n",
|
||
"85083 3.119646e+07 25.14 \n",
|
||
"38905 1.111942e+07 27.50 \n",
|
||
"16144 8.501684e+06 20.00 \n",
|
||
"54508 4.624727e+07 25.30 \n",
|
||
"\n",
|
||
"[18168 rows x 6 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>impact_damage_index</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>20406</th>\n",
|
||
" <td>0.000017</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>74443</th>\n",
|
||
" <td>0.000324</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>74306</th>\n",
|
||
" <td>0.000185</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>45943</th>\n",
|
||
" <td>0.000186</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>62859</th>\n",
|
||
" <td>0.000041</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>51634</th>\n",
|
||
" <td>0.000041</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>85083</th>\n",
|
||
" <td>0.000066</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>38905</th>\n",
|
||
" <td>0.000031</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>16144</th>\n",
|
||
" <td>0.002906</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>54508</th>\n",
|
||
" <td>0.000059</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>18168 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" impact_damage_index\n",
|
||
"20406 0.000017\n",
|
||
"74443 0.000324\n",
|
||
"74306 0.000185\n",
|
||
"45943 0.000186\n",
|
||
"62859 0.000041\n",
|
||
"... ...\n",
|
||
"51634 0.000041\n",
|
||
"85083 0.000066\n",
|
||
"38905 0.000031\n",
|
||
"16144 0.002906\n",
|
||
"54508 0.000059\n",
|
||
"\n",
|
||
"[18168 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from typing import Optional, Sequence, Tuple\n",
|
||
"import pandas as pd\n",
|
||
"from pandas import DataFrame\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"\n",
|
||
"def split_into_train_test(\n",
|
||
"    df_input: DataFrame,\n",
|
||
"    target_colname: str = \"impact_damage_index\",\n",
|
||
"    frac_train: float = 0.8,\n",
|
||
"    random_state: Optional[int] = None,\n",
|
||
"    drop_columns: Sequence[str] = (\"id\", \"sentry_object\", \"hazardous\", \"orbiting_body\", \"name\"),\n",
|
||
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
||
"    \"\"\"Split a DataFrame into train/test features and target.\n",
|
||
"\n",
|
||
"    Args:\n",
|
||
"        df_input: source data that contains the target column.\n",
|
||
"        target_colname: name of the target column.\n",
|
||
"        frac_train: share of rows used for training, strictly between 0 and 1.\n",
|
||
"        random_state: seed forwarded to train_test_split for reproducibility.\n",
|
||
"        drop_columns: non-predictive columns removed from the features when present;\n",
|
||
"            the row identifier \"id\" is included so models are not fitted on it.\n",
|
||
"\n",
|
||
"    Returns:\n",
|
||
"        (X_train, X_test, y_train, y_test); the targets are one-column DataFrames.\n",
|
||
"\n",
|
||
"    Raises:\n",
|
||
"        ValueError: if frac_train is out of range or the target column is missing.\n",
|
||
"    \"\"\"\n",
|
||
"    if not (0 < frac_train < 1):\n",
|
||
"        raise ValueError(\"Fraction must be between 0 and 1.\")\n",
|
||
"\n",
|
||
"    # The target column must be present\n",
|
||
"    if target_colname not in df_input.columns:\n",
|
||
"        raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
|
||
"\n",
|
||
"    # Features: everything except the target and the excluded columns (absent ones are ignored)\n",
|
||
"    X = df_input.drop(columns=[target_colname]).drop(columns=list(drop_columns), errors=\"ignore\")\n",
|
||
"    # Target kept as a one-column DataFrame\n",
|
||
"    y = df_input[[target_colname]]\n",
|
||
"\n",
|
||
"    X_train, X_test, y_train, y_test = train_test_split(\n",
|
||
"        X, y,\n",
|
||
"        test_size=(1.0 - frac_train),\n",
|
||
"        random_state=random_state,\n",
|
||
"    )\n",
|
||
"\n",
|
||
"    return X_train, X_test, y_train, y_test\n",
|
||
"\n",
|
||
"# Split the data into training and test sets\n",
|
||
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
|
||
"    df,\n",
|
||
"    target_colname=\"impact_damage_index\",\n",
|
||
"    frac_train=0.8,\n",
|
||
"    random_state=42,\n",
|
||
")\n",
|
||
"\n",
|
||
"# Show the results\n",
|
||
"display(\"X_train\", X_train)\n",
|
||
"display(\"y_train\", y_train)\n",
|
||
"\n",
|
||
"display(\"X_test\", X_test)\n",
|
||
"display(\"y_test\", y_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 204,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.pipeline import make_pipeline\n",
|
||
"from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
|
||
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
|
||
"\n",
|
||
"random_state = 9\n",
|
||
"\n",
|
||
"# Candidate regressors. Ridge, kNN and the MLP are sensitive to feature scale, and the raw\n",
|
||
"# features differ by orders of magnitude (miss_distance ~1e7 vs est_diameter ~1e-2..1),\n",
|
||
"# so those models standardize their inputs first.\n",
|
||
"models = {\n",
|
||
"    \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
|
||
"    \"linear_poly\": {\n",
|
||
"        \"model\": make_pipeline(\n",
|
||
"            PolynomialFeatures(degree=2),\n",
|
||
"            linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
||
"        )\n",
|
||
"    },\n",
|
||
"    \"linear_interact\": {\n",
|
||
"        \"model\": make_pipeline(\n",
|
||
"            PolynomialFeatures(interaction_only=True),\n",
|
||
"            linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
||
"        )\n",
|
||
"    },\n",
|
||
"    \"ridge\": {\"model\": make_pipeline(StandardScaler(), linear_model.RidgeCV())},\n",
|
||
"    \"decision_tree\": {\n",
|
||
"        \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
|
||
"    },\n",
|
||
"    \"knn\": {\n",
|
||
"        \"model\": make_pipeline(\n",
|
||
"            StandardScaler(), neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)\n",
|
||
"        )\n",
|
||
"    },\n",
|
||
"    \"random_forest\": {\n",
|
||
"        \"model\": ensemble.RandomForestRegressor(\n",
|
||
"            max_depth=7, random_state=random_state, n_jobs=-1\n",
|
||
"        )\n",
|
||
"    },\n",
|
||
"    \"mlp\": {\n",
|
||
"        \"model\": make_pipeline(\n",
|
||
"            StandardScaler(),\n",
|
||
"            neural_network.MLPRegressor(\n",
|
||
"                activation=\"tanh\",\n",
|
||
"                hidden_layer_sizes=(3,),\n",
|
||
"                max_iter=500,\n",
|
||
"                early_stopping=True,\n",
|
||
"                random_state=random_state,\n",
|
||
"            ),\n",
|
||
"        )\n",
|
||
"    },\n",
|
||
"}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Обучение моделей регрессии и расчёт метрик качества"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 205,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: linear\n",
|
||
"Model: linear_poly\n",
|
||
"Model: linear_interact\n",
|
||
"Model: ridge\n",
|
||
"Model: decision_tree\n",
|
||
"Model: knn\n",
|
||
"Model: random_forest\n",
|
||
"Model: mlp\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import math\n",
|
||
"from pandas import DataFrame\n",
|
||
"from sklearn import metrics\n",
|
||
"\n",
|
||
"# Fit every candidate on the training split and store its fitted estimator,\n",
|
||
"# predictions and quality metrics in its own entry of `models`.\n",
|
||
"for model_name, entry in models.items():\n",
|
||
"    print(f\"Model: {model_name}\")\n",
|
||
"\n",
|
||
"    fitted_model = entry[\"model\"].fit(X_train.values, y_train.values.ravel())\n",
|
||
"    y_train_pred = fitted_model.predict(X_train.values)\n",
|
||
"    y_test_pred = fitted_model.predict(X_test.values)\n",
|
||
"\n",
|
||
"    entry.update(\n",
|
||
"        fitted=fitted_model,\n",
|
||
"        train_preds=y_train_pred,\n",
|
||
"        preds=y_test_pred,\n",
|
||
"        RMSE_train=math.sqrt(metrics.mean_squared_error(y_train, y_train_pred)),\n",
|
||
"        RMSE_test=math.sqrt(metrics.mean_squared_error(y_test, y_test_pred)),\n",
|
||
"        # NOTE(review): square root of the MAE, not a standard metric; plain MAE may be intended\n",
|
||
"        RMAE_test=math.sqrt(metrics.mean_absolute_error(y_test, y_test_pred)),\n",
|
||
"        R2_test=metrics.r2_score(y_test, y_test_pred),\n",
|
||
"    )"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Вывод результатов оценки"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 206,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_38ff3_row0_col0, #T_38ff3_row0_col1, #T_38ff3_row1_col0, #T_38ff3_row1_col1, #T_38ff3_row2_col0, #T_38ff3_row2_col1, #T_38ff3_row3_col0, #T_38ff3_row3_col1, #T_38ff3_row4_col0, #T_38ff3_row4_col1, #T_38ff3_row5_col0, #T_38ff3_row5_col1, #T_38ff3_row6_col0, #T_38ff3_row6_col1 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_38ff3_row0_col2, #T_38ff3_row1_col2, #T_38ff3_row2_col2, #T_38ff3_row3_col2, #T_38ff3_row4_col2, #T_38ff3_row5_col2, #T_38ff3_row7_col3 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_38ff3_row0_col3, #T_38ff3_row1_col3, #T_38ff3_row2_col3, #T_38ff3_row3_col3, #T_38ff3_row4_col3, #T_38ff3_row5_col3, #T_38ff3_row6_col3, #T_38ff3_row7_col2 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_38ff3_row6_col2 {\n",
|
||
" background-color: #5002a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_38ff3_row7_col0, #T_38ff3_row7_col1 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_38ff3\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_38ff3_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
|
||
" <th id=\"T_38ff3_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
|
||
" <th id=\"T_38ff3_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
|
||
" <th id=\"T_38ff3_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_38ff3_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
|
||
" <td id=\"T_38ff3_row0_col0\" class=\"data row0 col0\" >0.000409</td>\n",
|
||
" <td id=\"T_38ff3_row0_col1\" class=\"data row0 col1\" >0.000711</td>\n",
|
||
" <td id=\"T_38ff3_row0_col2\" class=\"data row0 col2\" >0.012593</td>\n",
|
||
" <td id=\"T_38ff3_row0_col3\" class=\"data row0 col3\" >0.852564</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_38ff3_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
|
||
" <td id=\"T_38ff3_row1_col0\" class=\"data row1 col0\" >0.000511</td>\n",
|
||
" <td id=\"T_38ff3_row1_col1\" class=\"data row1 col1\" >0.001031</td>\n",
|
||
" <td id=\"T_38ff3_row1_col2\" class=\"data row1 col2\" >0.015170</td>\n",
|
||
" <td id=\"T_38ff3_row1_col3\" class=\"data row1 col3\" >0.689858</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_38ff3_level0_row2\" class=\"row_heading level0 row2\" >linear_poly</th>\n",
|
||
" <td id=\"T_38ff3_row2_col0\" class=\"data row2 col0\" >0.001217</td>\n",
|
||
" <td id=\"T_38ff3_row2_col1\" class=\"data row2 col1\" >0.001476</td>\n",
|
||
" <td id=\"T_38ff3_row2_col2\" class=\"data row2 col2\" >0.018001</td>\n",
|
||
" <td id=\"T_38ff3_row2_col3\" class=\"data row2 col3\" >0.364795</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_38ff3_level0_row3\" class=\"row_heading level0 row3\" >linear_interact</th>\n",
|
||
" <td id=\"T_38ff3_row3_col0\" class=\"data row3 col0\" >0.001263</td>\n",
|
||
" <td id=\"T_38ff3_row3_col1\" class=\"data row3 col1\" >0.001500</td>\n",
|
||
" <td id=\"T_38ff3_row3_col2\" class=\"data row3 col2\" >0.018235</td>\n",
|
||
" <td id=\"T_38ff3_row3_col3\" class=\"data row3 col3\" >0.343354</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_38ff3_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
|
||
" <td id=\"T_38ff3_row4_col0\" class=\"data row4 col0\" >0.001206</td>\n",
|
||
" <td id=\"T_38ff3_row4_col1\" class=\"data row4 col1\" >0.001611</td>\n",
|
||
" <td id=\"T_38ff3_row4_col2\" class=\"data row4 col2\" >0.019245</td>\n",
|
||
" <td id=\"T_38ff3_row4_col3\" class=\"data row4 col3\" >0.243014</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_38ff3_level0_row5\" class=\"row_heading level0 row5\" >linear</th>\n",
|
||
" <td id=\"T_38ff3_row5_col0\" class=\"data row5 col0\" >0.001382</td>\n",
|
||
" <td id=\"T_38ff3_row5_col1\" class=\"data row5 col1\" >0.001629</td>\n",
|
||
" <td id=\"T_38ff3_row5_col2\" class=\"data row5 col2\" >0.019724</td>\n",
|
||
" <td id=\"T_38ff3_row5_col3\" class=\"data row5 col3\" >0.225851</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_38ff3_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
|
||
" <td id=\"T_38ff3_row6_col0\" class=\"data row6 col0\" >0.001610</td>\n",
|
||
" <td id=\"T_38ff3_row6_col1\" class=\"data row6 col1\" >0.001852</td>\n",
|
||
" <td id=\"T_38ff3_row6_col2\" class=\"data row6 col2\" >0.023283</td>\n",
|
||
" <td id=\"T_38ff3_row6_col3\" class=\"data row6 col3\" >-0.000074</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_38ff3_level0_row7\" class=\"row_heading level0 row7\" >ridge</th>\n",
|
||
" <td id=\"T_38ff3_row7_col0\" class=\"data row7 col0\" >2.251826</td>\n",
|
||
" <td id=\"T_38ff3_row7_col1\" class=\"data row7 col1\" >2.248301</td>\n",
|
||
" <td id=\"T_38ff3_row7_col2\" class=\"data row7 col2\" >1.349327</td>\n",
|
||
" <td id=\"T_38ff3_row7_col3\" class=\"data row7 col3\" >-1474534.430780</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1b4c88a89e0>"
|
||
]
|
||
},
|
||
"execution_count": 206,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Gather the evaluation metrics of every model into a single table\n",
|
||
"metric_columns = [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
|
||
"reg_metrics = pd.DataFrame.from_dict(models, orient=\"index\")[metric_columns]\n",
|
||
"\n",
|
||
"# Display the models ranked by test RMSE with colour gradients on the metric columns\n",
|
||
"(\n",
|
||
"    reg_metrics.sort_values(by=\"RMSE_test\")\n",
|
||
"    .style\n",
|
||
"    .background_gradient(cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"])\n",
|
||
"    .background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок\n",
|
||
"\n",
|
||
"Получение лучшей модели"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 207,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'random_forest'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Name of the model with the lowest test RMSE\n",
|
||
"ranked_by_rmse = reg_metrics.sort_values(by=\"RMSE_test\")\n",
|
||
"best_model = str(ranked_by_rmse.index[0])\n",
|
||
"\n",
|
||
"display(best_model)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Подбор гиперпараметров методом поиска по сетке"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 209,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
||
" return fit_method(estimator, *args, **kwargs)\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Лучшие параметры: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 50}\n",
|
||
"Лучший результат (MSE): 5.418559949534169e-07\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
||
"from sklearn.ensemble import RandomForestRegressor  # regressor: the target is continuous\n",
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"\n",
|
||
"# Drop rows with missing values (reassignment instead of in-place mutation)\n",
|
||
"df = df.dropna()\n",
|
||
"# Predictors and regression target\n",
|
||
"X = df[[\"est_diameter_min\", \"est_diameter_max\", \"relative_velocity\", \"miss_distance\", \"absolute_magnitude\"]]\n",
|
||
"y = df['impact_damage_index']\n",
|
||
"\n",
|
||
"# Seeded so that the search result is reproducible\n",
|
||
"model = RandomForestRegressor(random_state=random_state)\n",
|
||
"\n",
|
||
"param_grid = {\n",
|
||
"    'n_estimators': [50, 100],\n",
|
||
"    'max_depth': [10, 20],\n",
|
||
"    'min_samples_split': [5, 10],\n",
|
||
"}\n",
|
||
"\n",
|
||
"# Hyperparameter search with Grid Search\n",
|
||
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
|
||
"                           scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
|
||
"\n",
|
||
"# Fit on the training split; y_train is a one-column DataFrame, so flatten it to 1-D\n",
|
||
"# (passing it as-is triggers sklearn's DataConversionWarning)\n",
|
||
"grid_search.fit(X_train, y_train.values.ravel())\n",
|
||
"\n",
|
||
"# Hyperparameter search results (best_score_ is negated MSE)\n",
|
||
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
|
||
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Обучение модели с новыми гиперпараметрами и сравнение новых и старых данных"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 210,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
||
" return fit_method(estimator, *args, **kwargs)\n",
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
||
" return fit_method(estimator, *args, **kwargs)\n",
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
||
" return fit_method(estimator, *args, **kwargs)\n",
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
||
" return fit_method(estimator, *args, **kwargs)\n",
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
||
" return fit_method(estimator, *args, **kwargs)\n",
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
||
" return fit_method(estimator, *args, **kwargs)\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Старые параметры: {'max_depth': 20, 'min_samples_split': 5, 'n_estimators': 50}\n",
|
||
"Лучший результат (MSE) на старых параметрах: 5.299415148966497e-07\n",
|
||
"\n",
|
||
"Новые параметры: {'max_depth': 20, 'min_samples_split': 10, 'n_estimators': 100}\n",
|
||
"Лучший результат (MSE) на новых параметрах: 5.355742455463778e-07\n",
|
||
"Среднеквадратическая ошибка (MSE) на тестовых данных: 4.772832137780905e-07\n",
|
||
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 0.0006908568692414446\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"from sklearn.ensemble import RandomForestRegressor\n",
|
||
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"\n",
|
||
"\n",
|
||
"old_param_grid = {\n",
|
||
" 'n_estimators': [50, 100], # Количество деревьев\n",
|
||
" 'max_depth': [ 10, 20], # Максимальная глубина дерева\n",
|
||
" 'min_samples_split': [5, 10] # Минимальное количество образцов для разбиения узла\n",
|
||
"}\n",
|
||
"\n",
|
||
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
|
||
" param_grid=old_param_grid,\n",
|
||
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
|
||
"\n",
|
||
"old_grid_search.fit(X_train, y_train)\n",
|
||
"\n",
|
||
"old_best_params = old_grid_search.best_params_\n",
|
||
"old_best_mse = -old_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
|
||
"\n",
|
||
"new_param_grid = {\n",
|
||
" 'n_estimators': [100],\n",
|
||
" 'max_depth': [20],\n",
|
||
" 'min_samples_split': [10]\n",
|
||
"}\n",
|
||
"\n",
|
||
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
|
||
" param_grid=new_param_grid,\n",
|
||
" scoring='neg_mean_squared_error', cv=2)\n",
|
||
"\n",
|
||
"new_grid_search.fit(X_train, y_train)\n",
|
||
"\n",
|
||
"new_best_params = new_grid_search.best_params_\n",
|
||
"new_best_mse = -new_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
|
||
"\n",
|
||
"model_best = RandomForestRegressor(**new_best_params)\n",
|
||
"model_best.fit(X_train, y_train)\n",
|
||
"\n",
|
||
"model_oldbest = RandomForestRegressor(**old_best_params)\n",
|
||
"model_oldbest.fit(X_train, y_train)\n",
|
||
"\n",
|
||
"y_pred = model_best.predict(X_test)\n",
|
||
"y_oldpred = model_oldbest.predict(X_test)\n",
|
||
"\n",
|
||
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
|
||
"rmse = np.sqrt(mse)\n",
|
||
"\n",
|
||
"print(\"Старые параметры:\", old_best_params)\n",
|
||
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
|
||
"print(\"\\nНовые параметры:\", new_best_params)\n",
|
||
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
|
||
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
|
||
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Попробуем визуализировать"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 212,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1000x500 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.figure(figsize=(10, 5))\n",
|
||
"plt.scatter(range(len(y_test)), y_test, label=\"Актуалочка\", color=\"black\", alpha=0.5)\n",
|
||
"plt.scatter(range(len(y_test)), y_pred, label=\"Предсказанные(новые параметры)\", color=\"blue\", alpha=0.5)\n",
|
||
"plt.scatter(range(len(y_test)), y_oldpred, label=\"Предсказанные(старые параметры)\", color=\"red\", alpha=0.5)\n",
|
||
"plt.xlabel(\"Выборка\")\n",
|
||
"plt.ylabel(\"Значения\")\n",
|
||
"plt.legend()\n",
|
||
"plt.title(\"Актуалочка vs Предсказанных значений (Новые and Старые Параметры)\")\n",
|
||
"plt.show()"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "aimenv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|