3935 lines
320 KiB
Plaintext
3935 lines
320 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Начало 4-й лабораторной\n",
|
||
"#### Ближайшие объекты к Земле"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Index(['id', 'name', 'est_diameter_min', 'est_diameter_max',\n",
|
||
" 'relative_velocity', 'miss_distance', 'orbiting_body', 'sentry_object',\n",
|
||
" 'absolute_magnitude', 'hazardous'],\n",
|
||
" dtype='object')\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>2162635</td>\n",
|
||
" <td>162635 (2000 SS164)</td>\n",
|
||
" <td>1.198271</td>\n",
|
||
" <td>2.679415</td>\n",
|
||
" <td>13569.249224</td>\n",
|
||
" <td>5.483974e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>16.73</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>2277475</td>\n",
|
||
" <td>277475 (2005 WK4)</td>\n",
|
||
" <td>0.265800</td>\n",
|
||
" <td>0.594347</td>\n",
|
||
" <td>73588.726663</td>\n",
|
||
" <td>6.143813e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.00</td>\n",
|
||
" <td>True</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>2512244</td>\n",
|
||
" <td>512244 (2015 YE18)</td>\n",
|
||
" <td>0.722030</td>\n",
|
||
" <td>1.614507</td>\n",
|
||
" <td>114258.692129</td>\n",
|
||
" <td>4.979872e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>17.83</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>3596030</td>\n",
|
||
" <td>(2012 BV13)</td>\n",
|
||
" <td>0.096506</td>\n",
|
||
" <td>0.215794</td>\n",
|
||
" <td>24764.303138</td>\n",
|
||
" <td>2.543497e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>3667127</td>\n",
|
||
" <td>(2014 GE35)</td>\n",
|
||
" <td>0.255009</td>\n",
|
||
" <td>0.570217</td>\n",
|
||
" <td>42737.733765</td>\n",
|
||
" <td>4.627557e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.09</td>\n",
|
||
" <td>True</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90831</th>\n",
|
||
" <td>3763337</td>\n",
|
||
" <td>(2016 VX1)</td>\n",
|
||
" <td>0.026580</td>\n",
|
||
" <td>0.059435</td>\n",
|
||
" <td>52078.886692</td>\n",
|
||
" <td>1.230039e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.00</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90832</th>\n",
|
||
" <td>3837603</td>\n",
|
||
" <td>(2019 AD3)</td>\n",
|
||
" <td>0.016771</td>\n",
|
||
" <td>0.037501</td>\n",
|
||
" <td>46114.605073</td>\n",
|
||
" <td>5.432121e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>26.00</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90833</th>\n",
|
||
" <td>54017201</td>\n",
|
||
" <td>(2020 JP3)</td>\n",
|
||
" <td>0.031956</td>\n",
|
||
" <td>0.071456</td>\n",
|
||
" <td>7566.807732</td>\n",
|
||
" <td>2.840077e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.60</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90834</th>\n",
|
||
" <td>54115824</td>\n",
|
||
" <td>(2021 CN5)</td>\n",
|
||
" <td>0.007321</td>\n",
|
||
" <td>0.016370</td>\n",
|
||
" <td>69199.154484</td>\n",
|
||
" <td>6.869206e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.80</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90835</th>\n",
|
||
" <td>54205447</td>\n",
|
||
" <td>(2021 TW7)</td>\n",
|
||
" <td>0.039862</td>\n",
|
||
" <td>0.089133</td>\n",
|
||
" <td>27024.455553</td>\n",
|
||
" <td>5.977213e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.12</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>90836 rows × 10 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"0 2162635 162635 (2000 SS164) 1.198271 2.679415 \n",
|
||
"1 2277475 277475 (2005 WK4) 0.265800 0.594347 \n",
|
||
"2 2512244 512244 (2015 YE18) 0.722030 1.614507 \n",
|
||
"3 3596030 (2012 BV13) 0.096506 0.215794 \n",
|
||
"4 3667127 (2014 GE35) 0.255009 0.570217 \n",
|
||
"... ... ... ... ... \n",
|
||
"90831 3763337 (2016 VX1) 0.026580 0.059435 \n",
|
||
"90832 3837603 (2019 AD3) 0.016771 0.037501 \n",
|
||
"90833 54017201 (2020 JP3) 0.031956 0.071456 \n",
|
||
"90834 54115824 (2021 CN5) 0.007321 0.016370 \n",
|
||
"90835 54205447 (2021 TW7) 0.039862 0.089133 \n",
|
||
"\n",
|
||
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
|
||
"0 13569.249224 5.483974e+07 Earth False \n",
|
||
"1 73588.726663 6.143813e+07 Earth False \n",
|
||
"2 114258.692129 4.979872e+07 Earth False \n",
|
||
"3 24764.303138 2.543497e+07 Earth False \n",
|
||
"4 42737.733765 4.627557e+07 Earth False \n",
|
||
"... ... ... ... ... \n",
|
||
"90831 52078.886692 1.230039e+07 Earth False \n",
|
||
"90832 46114.605073 5.432121e+07 Earth False \n",
|
||
"90833 7566.807732 2.840077e+07 Earth False \n",
|
||
"90834 69199.154484 6.869206e+07 Earth False \n",
|
||
"90835 27024.455553 5.977213e+07 Earth False \n",
|
||
"\n",
|
||
" absolute_magnitude hazardous \n",
|
||
"0 16.73 False \n",
|
||
"1 20.00 True \n",
|
||
"2 17.83 False \n",
|
||
"3 22.20 False \n",
|
||
"4 20.09 True \n",
|
||
"... ... ... \n",
|
||
"90831 25.00 False \n",
|
||
"90832 26.00 False \n",
|
||
"90833 24.60 False \n",
|
||
"90834 27.80 False \n",
|
||
"90835 24.12 False \n",
|
||
"\n",
|
||
"[90836 rows x 10 columns]"
|
||
]
|
||
},
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import seaborn as sns\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"from sklearn import set_config\n",
|
||
"\n",
|
||
"# Ask scikit-learn transformers to return pandas DataFrames.\n",
"set_config(transform_output=\"pandas\")\n",
"# Load the dataset from a relative path.\n",
"df = pd.read_csv(\".//static//csv//neo.csv\")\n",
"# Print the column index, then show the frame as the cell's rich output.\n",
"print(df.columns)\n",
"df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Бизнес-цели:\n",
|
||
"\n",
|
||
"1. Идентификация потенциально опасных объектов\n",
|
||
"\n",
|
||
"Описание: классифицировать астероиды как потенциально опасные или безопасные (используя целевой признак \"hazardous\"). Эта задача актуальна для оценки рисков и подготовки соответствующих действий по защите Земли.\n",
|
||
"\n",
|
||
"2. Прогнозирование минимального расстояния до Земли\n",
|
||
"\n",
|
||
"Описание: предсказать минимальное расстояние до Земли для новых объектов на основе характеристик астероида (скорости, размера и других параметров). Это позволит планировать исследования и наблюдения в зависимости от опасности. "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Определение достижимого уровня качества модели для первой задачи "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Разделение набора данных на обучающую и тестовую выборки (80/20) для задачи классификации\n",
|
||
"\n",
|
||
"Целевой признак -- hazardous"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2639</th>\n",
|
||
" <td>3634614</td>\n",
|
||
" <td>(2013 GT66)</td>\n",
|
||
" <td>0.024241</td>\n",
|
||
" <td>0.054205</td>\n",
|
||
" <td>43303.999094</td>\n",
|
||
" <td>4.814117e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29138</th>\n",
|
||
" <td>54143560</td>\n",
|
||
" <td>(2021 JU1)</td>\n",
|
||
" <td>0.030238</td>\n",
|
||
" <td>0.067615</td>\n",
|
||
" <td>21770.790211</td>\n",
|
||
" <td>5.646643e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.72</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>36927</th>\n",
|
||
" <td>3836085</td>\n",
|
||
" <td>(2018 VQ3)</td>\n",
|
||
" <td>0.201630</td>\n",
|
||
" <td>0.450858</td>\n",
|
||
" <td>109358.123029</td>\n",
|
||
" <td>6.435051e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.60</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61855</th>\n",
|
||
" <td>3769804</td>\n",
|
||
" <td>(2017 DJ34)</td>\n",
|
||
" <td>0.160160</td>\n",
|
||
" <td>0.358129</td>\n",
|
||
" <td>78494.609756</td>\n",
|
||
" <td>5.595780e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>21.10</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>15916</th>\n",
|
||
" <td>3824978</td>\n",
|
||
" <td>(2018 KS)</td>\n",
|
||
" <td>0.006991</td>\n",
|
||
" <td>0.015633</td>\n",
|
||
" <td>19077.749486</td>\n",
|
||
" <td>3.834648e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.90</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29491</th>\n",
|
||
" <td>3827304</td>\n",
|
||
" <td>(2018 RR1)</td>\n",
|
||
" <td>0.002658</td>\n",
|
||
" <td>0.005943</td>\n",
|
||
" <td>19826.895880</td>\n",
|
||
" <td>3.852881e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>30.00</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>18373</th>\n",
|
||
" <td>3735468</td>\n",
|
||
" <td>(2015 WY1)</td>\n",
|
||
" <td>0.103408</td>\n",
|
||
" <td>0.231228</td>\n",
|
||
" <td>82856.544926</td>\n",
|
||
" <td>7.314334e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.05</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>25031</th>\n",
|
||
" <td>3802041</td>\n",
|
||
" <td>(2018 FE3)</td>\n",
|
||
" <td>0.009651</td>\n",
|
||
" <td>0.021579</td>\n",
|
||
" <td>34243.774201</td>\n",
|
||
" <td>4.257719e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>35456</th>\n",
|
||
" <td>3430406</td>\n",
|
||
" <td>(2008 TR10)</td>\n",
|
||
" <td>0.221083</td>\n",
|
||
" <td>0.494356</td>\n",
|
||
" <td>19557.289783</td>\n",
|
||
" <td>2.152970e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.40</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>14305</th>\n",
|
||
" <td>3285300</td>\n",
|
||
" <td>(2005 OG3)</td>\n",
|
||
" <td>0.298233</td>\n",
|
||
" <td>0.666868</td>\n",
|
||
" <td>20309.404706</td>\n",
|
||
" <td>1.770015e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>19.75</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>72668 rows × 10 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"2639 3634614 (2013 GT66) 0.024241 0.054205 \n",
|
||
"29138 54143560 (2021 JU1) 0.030238 0.067615 \n",
|
||
"36927 3836085 (2018 VQ3) 0.201630 0.450858 \n",
|
||
"61855 3769804 (2017 DJ34) 0.160160 0.358129 \n",
|
||
"15916 3824978 (2018 KS) 0.006991 0.015633 \n",
|
||
"... ... ... ... ... \n",
|
||
"29491 3827304 (2018 RR1) 0.002658 0.005943 \n",
|
||
"18373 3735468 (2015 WY1) 0.103408 0.231228 \n",
|
||
"25031 3802041 (2018 FE3) 0.009651 0.021579 \n",
|
||
"35456 3430406 (2008 TR10) 0.221083 0.494356 \n",
|
||
"14305 3285300 (2005 OG3) 0.298233 0.666868 \n",
|
||
"\n",
|
||
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
|
||
"2639 43303.999094 4.814117e+07 Earth False \n",
|
||
"29138 21770.790211 5.646643e+07 Earth False \n",
|
||
"36927 109358.123029 6.435051e+07 Earth False \n",
|
||
"61855 78494.609756 5.595780e+07 Earth False \n",
|
||
"15916 19077.749486 3.834648e+07 Earth False \n",
|
||
"... ... ... ... ... \n",
|
||
"29491 19826.895880 3.852881e+07 Earth False \n",
|
||
"18373 82856.544926 7.314334e+07 Earth False \n",
|
||
"25031 34243.774201 4.257719e+07 Earth False \n",
|
||
"35456 19557.289783 2.152970e+07 Earth False \n",
|
||
"14305 20309.404706 1.770015e+07 Earth False \n",
|
||
"\n",
|
||
" absolute_magnitude hazardous \n",
|
||
"2639 25.20 False \n",
|
||
"29138 24.72 False \n",
|
||
"36927 20.60 False \n",
|
||
"61855 21.10 False \n",
|
||
"15916 27.90 False \n",
|
||
"... ... ... \n",
|
||
"29491 30.00 False \n",
|
||
"18373 22.05 False \n",
|
||
"25031 27.20 False \n",
|
||
"35456 20.40 False \n",
|
||
"14305 19.75 False \n",
|
||
"\n",
|
||
"[72668 rows x 10 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2639</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29138</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>36927</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61855</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>15916</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29491</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>18373</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>25031</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>35456</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>14305</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>72668 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" hazardous\n",
|
||
"2639 False\n",
|
||
"29138 False\n",
|
||
"36927 False\n",
|
||
"61855 False\n",
|
||
"15916 False\n",
|
||
"... ...\n",
|
||
"29491 False\n",
|
||
"18373 False\n",
|
||
"25031 False\n",
|
||
"35456 False\n",
|
||
"14305 False\n",
|
||
"\n",
|
||
"[72668 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>9040</th>\n",
|
||
" <td>2474532</td>\n",
|
||
" <td>474532 (2003 VG1)</td>\n",
|
||
" <td>0.472667</td>\n",
|
||
" <td>1.056915</td>\n",
|
||
" <td>21779.237137</td>\n",
|
||
" <td>3.443050e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>18.75</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>67305</th>\n",
|
||
" <td>3774018</td>\n",
|
||
" <td>(2017 HF1)</td>\n",
|
||
" <td>0.084053</td>\n",
|
||
" <td>0.187949</td>\n",
|
||
" <td>53291.016226</td>\n",
|
||
" <td>6.862591e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.50</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>77741</th>\n",
|
||
" <td>54269585</td>\n",
|
||
" <td>(2022 GQ2)</td>\n",
|
||
" <td>0.018220</td>\n",
|
||
" <td>0.040742</td>\n",
|
||
" <td>43089.046433</td>\n",
|
||
" <td>2.592726e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.82</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>81520</th>\n",
|
||
" <td>54097970</td>\n",
|
||
" <td>(2020 XS)</td>\n",
|
||
" <td>0.152952</td>\n",
|
||
" <td>0.342011</td>\n",
|
||
" <td>93246.455599</td>\n",
|
||
" <td>4.709054e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>21.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>508</th>\n",
|
||
" <td>3730802</td>\n",
|
||
" <td>(2015 TT238)</td>\n",
|
||
" <td>0.031956</td>\n",
|
||
" <td>0.071456</td>\n",
|
||
" <td>37708.258544</td>\n",
|
||
" <td>4.232149e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.60</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>28261</th>\n",
|
||
" <td>3532365</td>\n",
|
||
" <td>(2010 MH1)</td>\n",
|
||
" <td>0.139494</td>\n",
|
||
" <td>0.311918</td>\n",
|
||
" <td>37604.980238</td>\n",
|
||
" <td>7.369507e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>21.40</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1159</th>\n",
|
||
" <td>54073345</td>\n",
|
||
" <td>(2020 UE)</td>\n",
|
||
" <td>0.020728</td>\n",
|
||
" <td>0.046349</td>\n",
|
||
" <td>36720.077728</td>\n",
|
||
" <td>3.366114e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.54</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>48095</th>\n",
|
||
" <td>3836195</td>\n",
|
||
" <td>(2018 VT7)</td>\n",
|
||
" <td>0.006991</td>\n",
|
||
" <td>0.015633</td>\n",
|
||
" <td>7616.496535</td>\n",
|
||
" <td>6.376350e+06</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.90</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90234</th>\n",
|
||
" <td>3752902</td>\n",
|
||
" <td>(2016 JG12)</td>\n",
|
||
" <td>0.084053</td>\n",
|
||
" <td>0.187949</td>\n",
|
||
" <td>21894.554692</td>\n",
|
||
" <td>5.736984e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.50</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>12013</th>\n",
|
||
" <td>3445077</td>\n",
|
||
" <td>(2009 BM58)</td>\n",
|
||
" <td>0.038420</td>\n",
|
||
" <td>0.085909</td>\n",
|
||
" <td>49828.611609</td>\n",
|
||
" <td>4.305599e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>18168 rows × 10 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"9040 2474532 474532 (2003 VG1) 0.472667 1.056915 \n",
|
||
"67305 3774018 (2017 HF1) 0.084053 0.187949 \n",
|
||
"77741 54269585 (2022 GQ2) 0.018220 0.040742 \n",
|
||
"81520 54097970 (2020 XS) 0.152952 0.342011 \n",
|
||
"508 3730802 (2015 TT238) 0.031956 0.071456 \n",
|
||
"... ... ... ... ... \n",
|
||
"28261 3532365 (2010 MH1) 0.139494 0.311918 \n",
|
||
"1159 54073345 (2020 UE) 0.020728 0.046349 \n",
|
||
"48095 3836195 (2018 VT7) 0.006991 0.015633 \n",
|
||
"90234 3752902 (2016 JG12) 0.084053 0.187949 \n",
|
||
"12013 3445077 (2009 BM58) 0.038420 0.085909 \n",
|
||
"\n",
|
||
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
|
||
"9040 21779.237137 3.443050e+07 Earth False \n",
|
||
"67305 53291.016226 6.862591e+07 Earth False \n",
|
||
"77741 43089.046433 2.592726e+07 Earth False \n",
|
||
"81520 93246.455599 4.709054e+07 Earth False \n",
|
||
"508 37708.258544 4.232149e+07 Earth False \n",
|
||
"... ... ... ... ... \n",
|
||
"28261 37604.980238 7.369507e+07 Earth False \n",
|
||
"1159 36720.077728 3.366114e+07 Earth False \n",
|
||
"48095 7616.496535 6.376350e+06 Earth False \n",
|
||
"90234 21894.554692 5.736984e+07 Earth False \n",
|
||
"12013 49828.611609 4.305599e+07 Earth False \n",
|
||
"\n",
|
||
" absolute_magnitude hazardous \n",
|
||
"9040 18.75 False \n",
|
||
"67305 22.50 False \n",
|
||
"77741 25.82 False \n",
|
||
"81520 21.20 False \n",
|
||
"508 24.60 False \n",
|
||
"... ... ... \n",
|
||
"28261 21.40 False \n",
|
||
"1159 25.54 False \n",
|
||
"48095 27.90 False \n",
|
||
"90234 22.50 False \n",
|
||
"12013 24.20 False \n",
|
||
"\n",
|
||
"[18168 rows x 10 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>9040</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>67305</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>77741</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>81520</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>508</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>28261</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1159</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>48095</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90234</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>12013</th>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>18168 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" hazardous\n",
|
||
"9040 False\n",
|
||
"67305 False\n",
|
||
"77741 False\n",
|
||
"81520 False\n",
|
||
"508 False\n",
|
||
"... ...\n",
|
||
"28261 False\n",
|
||
"1159 False\n",
|
||
"48095 False\n",
|
||
"90234 False\n",
|
||
"12013 False\n",
|
||
"\n",
|
||
"[18168 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from typing import Tuple\n",
|
||
"import pandas as pd\n",
|
||
"from pandas import DataFrame\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"\n",
|
||
"# Fixed seed so that the split below is reproducible between runs.\n",
"random_state = 42\n",
"\n",
"def split_stratified_into_train_val_test(\n",
"    df_input,\n",
"    stratify_colname=\"y\",\n",
"    frac_train=0.6,\n",
"    frac_val=0.15,\n",
"    frac_test=0.25,\n",
"    random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"Split a dataframe into stratified train, validation and test parts.\n",
"\n",
"    Args:\n",
"        df_input: source dataframe; every column, including the stratification\n",
"            column, is kept in the returned X parts.\n",
"        stratify_colname: name of the column whose class proportions are preserved.\n",
"        frac_train: fraction of rows for the training part.\n",
"        frac_val: fraction of rows for the validation part; a value <= 0 skips it.\n",
"        frac_test: fraction of rows for the test part.\n",
"        random_state: seed forwarded to train_test_split.\n",
"\n",
"    Returns:\n",
"        (df_train, df_val, df_test, y_train, y_val, y_test). The y parts are\n",
"        one-column dataframes. When frac_val <= 0 the two validation parts are\n",
"        empty dataframes.\n",
"\n",
"    Raises:\n",
"        ValueError: if the fractions do not sum to 1.0 (within floating-point\n",
"            tolerance) or stratify_colname is not a column of df_input.\n",
"    \"\"\"\n",
"    import math\n",
"\n",
"    # Tolerant comparison: an exact != check rejects valid inputs such as\n",
"    # 0.7 + 0.2 + 0.1, which evaluates to 0.9999999999999999 in floating point.\n",
"    if not math.isclose(frac_train + frac_val + frac_test, 1.0):\n",
"        raise ValueError(\n",
"            \"fractions %f, %f, %f do not add up to 1.0\"\n",
"            % (frac_train, frac_val, frac_test)\n",
"        )\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"    X = df_input  # Contains all columns.\n",
"    y = df_input[\n",
"        [stratify_colname]\n",
"    ]  # Dataframe of just the column on which to stratify.\n",
"    # Split original dataframe into train and temp dataframes.\n",
"    df_train, df_temp, y_train, y_temp = train_test_split(\n",
"        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
"    )\n",
"    if frac_val <= 0:\n",
"        # No validation part requested: the temp part is the test part.\n",
"        assert len(df_input) == len(df_train) + len(df_temp)\n",
"        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
"    # Split the temp dataframe into val and test dataframes.\n",
"    relative_frac_test = frac_test / (frac_val + frac_test)\n",
"    df_val, df_test, y_val, y_test = train_test_split(\n",
"        df_temp,\n",
"        y_temp,\n",
"        stratify=y_temp,\n",
"        test_size=relative_frac_test,\n",
"        random_state=random_state,\n",
"    )\n",
"    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
"    return df_train, df_val, df_test, y_train, y_val, y_test\n",
|
||
"\n",
|
||
"# 80/20 stratified split on the target column; frac_val=0 means the\n",
"# validation parts come back as empty dataframes.\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
"    df,\n",
"    stratify_colname=\"hazardous\",\n",
"    frac_train=0.80,\n",
"    frac_val=0,\n",
"    frac_test=0.20,\n",
"    random_state=random_state,\n",
")\n",
"\n",
"# Show every part, each preceded by its label.\n",
"for part_label, part_frame in (\n",
"    (\"X_train\", X_train),\n",
"    (\"y_train\", y_train),\n",
"    (\"X_test\", X_test),\n",
"    (\"y_test\", y_test),\n",
"):\n",
"    display(part_label, part_frame)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Формирование конвейера для классификации данных\n",
|
||
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
|
||
"\n",
|
||
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
|
||
"\n",
|
||
"features_preprocessing -- трансформер для предобработки признаков\n",
|
||
"\n",
|
||
"features_engineering -- трансформер для конструирования признаков\n",
|
||
"\n",
|
||
"drop_columns -- трансформер для удаления колонок\n",
|
||
"\n",
|
||
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from sklearn.base import BaseEstimator, TransformerMixin\n",
|
||
"from sklearn.compose import ColumnTransformer\n",
|
||
"from sklearn.discriminant_analysis import StandardScaler\n",
|
||
"from sklearn.impute import SimpleImputer\n",
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"from sklearn.preprocessing import OneHotEncoder\n",
|
||
"\n",
|
||
"class EarthObjectsFeatures(BaseEstimator, TransformerMixin):\n",
"    \"\"\"Stateless transformer that appends a Length_to_Width_Ratio column.\n",
"\n",
"    The new column is the element-wise ratio of columns 'x' and 'y'.\n",
"    NOTE(review): 'x' and 'y' are not among the dataset columns printed by\n",
"    this notebook, so this looks carried over from another dataset -- confirm\n",
"    the intended source columns before using the transformer.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        pass\n",
"\n",
"    def fit(self, X, y=None):\n",
"        \"\"\"Nothing is learned from the data; return self unchanged.\"\"\"\n",
"        return self\n",
"\n",
"    def transform(self, X, y=None):\n",
"        \"\"\"Return a new frame with the ratio column added.\n",
"\n",
"        assign() builds a new dataframe, so the caller's X is no longer\n",
"        modified in place.\n",
"        \"\"\"\n",
"        return X.assign(Length_to_Width_Ratio=X[\"x\"] / X[\"y\"])\n",
"\n",
"    def get_feature_names_out(self, features_in):\n",
"        \"\"\"Return the input feature names followed by the added column name.\"\"\"\n",
"        return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
|
||
" \n",
|
||
"\n",
|
||
"# Columns removed by the last pipeline step.\n",
"columns_to_drop = [\"name\", \"orbiting_body\"]\n",
"# Columns that are imputed and standardised.\n",
"# NOTE(review): \"hazardous\" is the classification target of this notebook, yet\n",
"# it is listed here and therefore stays in the transformed feature frame --\n",
"# confirm this is intended, otherwise models can read the answer directly.\n",
"num_columns = [\n",
"    \"est_diameter_min\",\n",
"    \"est_diameter_max\",\n",
"    \"relative_velocity\",\n",
"    \"miss_distance\",\n",
"    \"sentry_object\",\n",
"    \"absolute_magnitude\",\n",
"    \"hazardous\",\n",
"]\n",
"# No categorical columns are encoded for this dataset.\n",
"cat_columns = []\n",
"\n",
"# Numeric branch: median imputation followed by standardisation.\n",
"# NOTE(review): StandardScaler is imported in this cell from\n",
"# sklearn.discriminant_analysis; its documented home is sklearn.preprocessing.\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(steps=[(\"imputer\", num_imputer), (\"scaler\", num_scaler)])\n",
"\n",
"# Categorical branch: constant-fill imputation followed by one-hot encoding.\n",
"cat_imputer = SimpleImputer(fill_value=\"unknown\", strategy=\"constant\")\n",
"cat_encoder = OneHotEncoder(drop=\"first\", handle_unknown=\"ignore\", sparse_output=False)\n",
"preprocessing_cat = Pipeline(steps=[(\"imputer\", cat_imputer), (\"encoder\", cat_encoder)])\n",
"\n",
"# Apply each branch to its columns; all other columns pass through untouched.\n",
"features_preprocessing = ColumnTransformer(\n",
"    transformers=[\n",
"        (\"prepocessing_num\", preprocessing_num, num_columns),\n",
"        (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
"    verbose_feature_names_out=False,\n",
")\n",
"\n",
"# Drop the columns listed in columns_to_drop and keep the rest.\n",
"drop_columns = ColumnTransformer(\n",
"    transformers=[(\"drop_columns\", \"drop\", columns_to_drop)],\n",
"    remainder=\"passthrough\",\n",
"    verbose_feature_names_out=False,\n",
")\n",
"\n",
"# NOTE(review): \"Cabin_type\" is not among the dataset columns printed by this\n",
"# notebook and this transformer is not part of pipeline_end below -- it looks\n",
"# like a leftover from another dataset; confirm before using it.\n",
"features_postprocessing = ColumnTransformer(\n",
"    transformers=[(\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"])],\n",
"    remainder=\"passthrough\",\n",
"    verbose_feature_names_out=False,\n",
")\n",
"\n",
"# Main pipeline: feature preprocessing, then column removal.\n",
"pipeline_end = Pipeline(\n",
"    steps=[\n",
"        (\"features_preprocessing\", features_preprocessing),\n",
"        (\"drop_columns\", drop_columns),\n",
"    ]\n",
")\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Демонстрация работы конвейера"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" <th>id</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2639</th>\n",
|
||
" <td>-0.331616</td>\n",
|
||
" <td>-0.331616</td>\n",
|
||
" <td>-0.188160</td>\n",
|
||
" <td>0.494297</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.577785</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3634614</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29138</th>\n",
|
||
" <td>-0.312486</td>\n",
|
||
" <td>-0.312486</td>\n",
|
||
" <td>-1.040729</td>\n",
|
||
" <td>0.866716</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.412170</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>54143560</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>36927</th>\n",
|
||
" <td>0.234246</td>\n",
|
||
" <td>0.234246</td>\n",
|
||
" <td>2.427134</td>\n",
|
||
" <td>1.219399</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-1.009355</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3836085</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61855</th>\n",
|
||
" <td>0.101960</td>\n",
|
||
" <td>0.101960</td>\n",
|
||
" <td>1.205148</td>\n",
|
||
" <td>0.843963</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-0.836840</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3769804</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>15916</th>\n",
|
||
" <td>-0.386643</td>\n",
|
||
" <td>-0.386643</td>\n",
|
||
" <td>-1.147355</td>\n",
|
||
" <td>0.056145</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.509367</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3824978</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>29491</th>\n",
|
||
" <td>-0.400466</td>\n",
|
||
" <td>-0.400466</td>\n",
|
||
" <td>-1.117694</td>\n",
|
||
" <td>0.064301</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>2.233931</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3827304</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>18373</th>\n",
|
||
" <td>-0.079077</td>\n",
|
||
" <td>-0.079077</td>\n",
|
||
" <td>1.377851</td>\n",
|
||
" <td>1.612734</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-0.509061</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3735468</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>25031</th>\n",
|
||
" <td>-0.378159</td>\n",
|
||
" <td>-0.378159</td>\n",
|
||
" <td>-0.546884</td>\n",
|
||
" <td>0.245400</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.267846</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3802041</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>35456</th>\n",
|
||
" <td>0.296300</td>\n",
|
||
" <td>0.296300</td>\n",
|
||
" <td>-1.128369</td>\n",
|
||
" <td>-0.696130</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-1.078361</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3430406</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>14305</th>\n",
|
||
" <td>0.542404</td>\n",
|
||
" <td>0.542404</td>\n",
|
||
" <td>-1.098590</td>\n",
|
||
" <td>-0.867440</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-1.302631</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3285300</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>72668 rows × 8 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" est_diameter_min est_diameter_max relative_velocity miss_distance \\\n",
|
||
"2639 -0.331616 -0.331616 -0.188160 0.494297 \n",
|
||
"29138 -0.312486 -0.312486 -1.040729 0.866716 \n",
|
||
"36927 0.234246 0.234246 2.427134 1.219399 \n",
|
||
"61855 0.101960 0.101960 1.205148 0.843963 \n",
|
||
"15916 -0.386643 -0.386643 -1.147355 0.056145 \n",
|
||
"... ... ... ... ... \n",
|
||
"29491 -0.400466 -0.400466 -1.117694 0.064301 \n",
|
||
"18373 -0.079077 -0.079077 1.377851 1.612734 \n",
|
||
"25031 -0.378159 -0.378159 -0.546884 0.245400 \n",
|
||
"35456 0.296300 0.296300 -1.128369 -0.696130 \n",
|
||
"14305 0.542404 0.542404 -1.098590 -0.867440 \n",
|
||
"\n",
|
||
" sentry_object absolute_magnitude hazardous id \n",
|
||
"2639 0.0 0.577785 -0.328347 3634614 \n",
|
||
"29138 0.0 0.412170 -0.328347 54143560 \n",
|
||
"36927 0.0 -1.009355 -0.328347 3836085 \n",
|
||
"61855 0.0 -0.836840 -0.328347 3769804 \n",
|
||
"15916 0.0 1.509367 -0.328347 3824978 \n",
|
||
"... ... ... ... ... \n",
|
||
"29491 0.0 2.233931 -0.328347 3827304 \n",
|
||
"18373 0.0 -0.509061 -0.328347 3735468 \n",
|
||
"25031 0.0 1.267846 -0.328347 3802041 \n",
|
||
"35456 0.0 -1.078361 -0.328347 3430406 \n",
|
||
"14305 0.0 -1.302631 -0.328347 3285300 \n",
|
||
"\n",
|
||
"[72668 rows x 8 columns]"
|
||
]
|
||
},
|
||
"execution_count": 19,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
||
"preprocessed_df = pd.DataFrame(\n",
|
||
" preprocessing_result,\n",
|
||
" columns=pipeline_end.get_feature_names_out(),\n",
|
||
")\n",
|
||
"\n",
|
||
"preprocessed_df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Формирование набора моделей для классификации\n",
|
||
" logistic -- логистическая регрессия\n",
|
||
"\n",
|
||
"ridge -- гребневая регрессия\n",
|
||
"\n",
|
||
"decision_tree -- дерево решений\n",
|
||
"\n",
|
||
"knn -- k-ближайших соседей\n",
|
||
"\n",
|
||
"naive_bayes -- наивный Байесовский классификатор\n",
|
||
"\n",
|
||
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
|
||
"\n",
|
||
"random_forest -- метод случайного леса (набор деревьев решений)\n",
|
||
"\n",
|
||
"mlp -- многослойный персептрон (нейронная сеть)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
|
||
"\n",
|
||
"class_models = {\n",
|
||
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
|
||
" # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
|
||
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
|
||
" \"decision_tree\": {\n",
|
||
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
|
||
" },\n",
|
||
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
|
||
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
|
||
"    \"gradient_boosting\": {\n",
|
||
"        # Seeded like the other tree-based models here so repeated runs give the same scores.\n",
|
||
"        \"model\": ensemble.GradientBoostingClassifier(\n",
|
||
"            n_estimators=210, random_state=random_state\n",
|
||
"        )\n",
|
||
"    },\n",
|
||
" \"random_forest\": {\n",
|
||
" \"model\": ensemble.RandomForestClassifier(\n",
|
||
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
|
||
" )\n",
|
||
" },\n",
|
||
" \"mlp\": {\n",
|
||
" \"model\": neural_network.MLPClassifier(\n",
|
||
" hidden_layer_sizes=(7,),\n",
|
||
" max_iter=500,\n",
|
||
" early_stopping=True,\n",
|
||
" random_state=random_state,\n",
|
||
" )\n",
|
||
" },\n",
|
||
"}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: logistic\n",
|
||
"Model: ridge\n",
|
||
"Model: decision_tree\n",
|
||
"Model: knn\n",
|
||
"Model: naive_bayes\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: gradient_boosting\n",
|
||
"Model: random_forest\n",
|
||
"Model: mlp\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"\n",
|
||
"for model_name, model_entry in class_models.items():\n",
|
||
"    print(f\"Model: {model_name}\")\n",
|
||
"\n",
|
||
"    # Chain the shared preprocessing with this model and fit on the training split.\n",
|
||
"    model_pipeline = Pipeline(\n",
|
||
"        [(\"pipeline\", pipeline_end), (\"model\", model_entry[\"model\"])]\n",
|
||
"    )\n",
|
||
"    model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
||
"\n",
|
||
"    y_train_predict = model_pipeline.predict(X_train)\n",
|
||
"    # Test labels come from thresholding the positive-class probability at 0.5.\n",
|
||
"    y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
||
"    y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
||
"\n",
|
||
"    model_entry.update(\n",
|
||
"        {\n",
|
||
"            \"pipeline\": model_pipeline,\n",
|
||
"            \"probs\": y_test_probs,\n",
|
||
"            \"preds\": y_test_predict,\n",
|
||
"            # zero_division=0 keeps the 0.0 score for models that predict no positives\n",
|
||
"            # but avoids the UndefinedMetricWarning noise in the cell output.\n",
|
||
"            \"Precision_train\": metrics.precision_score(\n",
|
||
"                y_train, y_train_predict, zero_division=0\n",
|
||
"            ),\n",
|
||
"            \"Precision_test\": metrics.precision_score(\n",
|
||
"                y_test, y_test_predict, zero_division=0\n",
|
||
"            ),\n",
|
||
"            \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
|
||
"            \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
|
||
"            \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
|
||
"            \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
|
||
"            \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_probs),\n",
|
||
"            \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
|
||
"            \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
|
||
"            \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
|
||
"            \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
|
||
"            \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
|
||
"        }\n",
|
||
"    )"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Сводная таблица оценок качества для использованных моделей классификации"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1200x1000 with 16 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"\n",
|
||
"# Enough rows for every model in a two-column grid, including an odd number of models.\n",
|
||
"n_rows = (len(class_models) + 1) // 2\n",
|
||
"_, ax = plt.subplots(n_rows, 2, figsize=(12, 10), sharex=False, sharey=False)\n",
|
||
"for index, key in enumerate(class_models.keys()):\n",
|
||
"    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
|
||
"    # confusion_matrix orders classes ascending, so index 0 is the negative class\n",
|
||
"    # (not hazardous) and index 1 is the positive class (hazardous).\n",
|
||
"    disp = ConfusionMatrixDisplay(\n",
|
||
"        confusion_matrix=c_matrix, display_labels=[\"safe\", \"hazardous\"]\n",
|
||
"    ).plot(ax=ax.flat[index])\n",
|
||
"    disp.ax_.set_title(key)\n",
|
||
"\n",
|
||
"# Hide the spare axis left over when the number of models is odd.\n",
|
||
"for unused_ax in ax.flat[len(class_models):]:\n",
|
||
"    unused_ax.set_visible(False)\n",
|
||
"\n",
|
||
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"16400 - количество истинно отрицательных результатов (True Negatives): объекты класса \"safe\", которые модель правильно отнесла к безопасным.\n",
|
||
"\n",
|
||
"1768 в некоторых моделях - количество ложных отрицательных результатов (False Negatives): объекты, которые на самом деле принадлежат к классу \"hazardous\", но были отнесены моделью к классу \"safe\". \n",
|
||
"\n",
|
||
"Исходя из значений True Negatives и False Negatives, можно сказать, что модели хорошо распознают класс \"safe\". При этом 1768 - это общее число объектов класса \"hazardous\" в тестовой выборке, поэтому 1768 ложных отрицательных результатов в некоторых моделях говорят нам о том, что такие модели пропускают все опасные объекты.\n",
|
||
"\n",
|
||
"Точность, полнота, верность (аккуратность), F-мера"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_371be_row0_col0, #T_371be_row0_col1, #T_371be_row0_col2, #T_371be_row0_col3, #T_371be_row1_col0, #T_371be_row1_col1, #T_371be_row1_col2, #T_371be_row1_col3, #T_371be_row2_col0, #T_371be_row2_col1, #T_371be_row2_col2, #T_371be_row2_col3, #T_371be_row3_col0, #T_371be_row3_col1, #T_371be_row3_col2, #T_371be_row3_col3, #T_371be_row7_col2, #T_371be_row7_col3 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_371be_row0_col4, #T_371be_row0_col5, #T_371be_row0_col6, #T_371be_row0_col7, #T_371be_row1_col4, #T_371be_row1_col5, #T_371be_row1_col6, #T_371be_row1_col7, #T_371be_row2_col4, #T_371be_row2_col5, #T_371be_row2_col6, #T_371be_row2_col7, #T_371be_row3_col4, #T_371be_row3_col5, #T_371be_row3_col6, #T_371be_row3_col7 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row4_col0 {\n",
|
||
" background-color: #86d549;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_371be_row4_col1 {\n",
|
||
" background-color: #77d153;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_371be_row4_col2 {\n",
|
||
" background-color: #63cb5f;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_371be_row4_col3 {\n",
|
||
" background-color: #4ac16d;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_371be_row4_col4 {\n",
|
||
" background-color: #c03a83;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row4_col5 {\n",
|
||
" background-color: #b32c8e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row4_col6 {\n",
|
||
" background-color: #c7427c;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row4_col7 {\n",
|
||
" background-color: #bd3786;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row5_col0, #T_371be_row5_col1, #T_371be_row5_col2, #T_371be_row5_col3, #T_371be_row6_col0, #T_371be_row6_col1, #T_371be_row6_col2, #T_371be_row6_col3 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row5_col4, #T_371be_row6_col4 {\n",
|
||
" background-color: #8004a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row5_col5, #T_371be_row6_col5 {\n",
|
||
" background-color: #7d03a8;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row5_col6, #T_371be_row5_col7, #T_371be_row6_col6, #T_371be_row6_col7, #T_371be_row7_col4, #T_371be_row7_col5 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row7_col0 {\n",
|
||
" background-color: #25ac82;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row7_col1 {\n",
|
||
" background-color: #26ad81;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row7_col6 {\n",
|
||
" background-color: #ac2694;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_371be_row7_col7 {\n",
|
||
" background-color: #ad2793;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_371be\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_371be_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
||
" <th id=\"T_371be_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
||
" <th id=\"T_371be_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
||
" <th id=\"T_371be_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
||
" <th id=\"T_371be_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
||
" <th id=\"T_371be_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_371be_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
||
" <th id=\"T_371be_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_371be_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
|
||
" <td id=\"T_371be_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_371be_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
|
||
" <td id=\"T_371be_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_371be_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
||
" <td id=\"T_371be_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_371be_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
|
||
" <td id=\"T_371be_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_371be_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
|
||
" <td id=\"T_371be_row4_col0\" class=\"data row4 col0\" >0.884596</td>\n",
|
||
" <td id=\"T_371be_row4_col1\" class=\"data row4 col1\" >0.826374</td>\n",
|
||
" <td id=\"T_371be_row4_col2\" class=\"data row4 col2\" >0.744627</td>\n",
|
||
" <td id=\"T_371be_row4_col3\" class=\"data row4 col3\" >0.638009</td>\n",
|
||
" <td id=\"T_371be_row4_col4\" class=\"data row4 col4\" >0.965693</td>\n",
|
||
" <td id=\"T_371be_row4_col5\" class=\"data row4 col5\" >0.951728</td>\n",
|
||
" <td id=\"T_371be_row4_col6\" class=\"data row4 col6\" >0.808599</td>\n",
|
||
" <td id=\"T_371be_row4_col7\" class=\"data row4 col7\" >0.720077</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_371be_level0_row5\" class=\"row_heading level0 row5\" >naive_bayes</th>\n",
|
||
" <td id=\"T_371be_row5_col0\" class=\"data row5 col0\" >0.000000</td>\n",
|
||
" <td id=\"T_371be_row5_col1\" class=\"data row5 col1\" >0.000000</td>\n",
|
||
" <td id=\"T_371be_row5_col2\" class=\"data row5 col2\" >0.000000</td>\n",
|
||
" <td id=\"T_371be_row5_col3\" class=\"data row5 col3\" >0.000000</td>\n",
|
||
" <td id=\"T_371be_row5_col4\" class=\"data row5 col4\" >0.902681</td>\n",
|
||
" <td id=\"T_371be_row5_col5\" class=\"data row5 col5\" >0.902686</td>\n",
|
||
" <td id=\"T_371be_row5_col6\" class=\"data row5 col6\" >0.000000</td>\n",
|
||
" <td id=\"T_371be_row5_col7\" class=\"data row5 col7\" >0.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_371be_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
|
||
" <td id=\"T_371be_row6_col0\" class=\"data row6 col0\" >0.000000</td>\n",
|
||
" <td id=\"T_371be_row6_col1\" class=\"data row6 col1\" >0.000000</td>\n",
|
||
" <td id=\"T_371be_row6_col2\" class=\"data row6 col2\" >0.000000</td>\n",
|
||
" <td id=\"T_371be_row6_col3\" class=\"data row6 col3\" >0.000000</td>\n",
|
||
" <td id=\"T_371be_row6_col4\" class=\"data row6 col4\" >0.902681</td>\n",
|
||
" <td id=\"T_371be_row6_col5\" class=\"data row6 col5\" >0.902686</td>\n",
|
||
" <td id=\"T_371be_row6_col6\" class=\"data row6 col6\" >0.000000</td>\n",
|
||
" <td id=\"T_371be_row6_col7\" class=\"data row6 col7\" >0.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_371be_level0_row7\" class=\"row_heading level0 row7\" >ridge</th>\n",
|
||
" <td id=\"T_371be_row7_col0\" class=\"data row7 col0\" >0.415780</td>\n",
|
||
" <td id=\"T_371be_row7_col1\" class=\"data row7 col1\" >0.421253</td>\n",
|
||
" <td id=\"T_371be_row7_col2\" class=\"data row7 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row7_col3\" class=\"data row7 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_371be_row7_col4\" class=\"data row7 col4\" >0.863255</td>\n",
|
||
" <td id=\"T_371be_row7_col5\" class=\"data row7 col5\" >0.866303</td>\n",
|
||
" <td id=\"T_371be_row7_col6\" class=\"data row7 col6\" >0.587351</td>\n",
|
||
" <td id=\"T_371be_row7_col7\" class=\"data row7 col7\" >0.592791</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1dd731d3fe0>"
|
||
]
|
||
},
|
||
"execution_count": 24,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
||
" [\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" \"Accuracy_train\",\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_train\",\n",
|
||
" \"F1_test\",\n",
|
||
" ]\n",
|
||
"]\n",
|
||
"class_metrics.sort_values(\n",
|
||
" by=\"Accuracy_test\", ascending=False\n",
|
||
").style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Логистическая регрессия, дерево решений, случайный лес и градиентный бустинг показывают значение 1.0 по всем метрикам на обучающем и тестовом наборах данных; столь идеальный результат обычно указывает на утечку целевого признака во входные данные, поэтому его нужно перепроверить. KNN даёт средние результаты (F1 на тесте около 0.72), а ридж-регрессия при полноте 1.0 имеет низкую точность (около 0.42).\n",
|
||
"\n",
|
||
"Модели Naive Bayes и MLP не обнаруживают ни одного опасного объекта (Precision, Recall и F1 равны 0); их высокая Accuracy (около 0.90) объясняется только дисбалансом классов. \n",
|
||
"\n",
|
||
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_9ba87_row0_col0, #T_9ba87_row0_col1, #T_9ba87_row1_col0, #T_9ba87_row1_col1, #T_9ba87_row2_col0, #T_9ba87_row2_col1, #T_9ba87_row3_col0, #T_9ba87_row3_col1 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_9ba87_row0_col2, #T_9ba87_row0_col3, #T_9ba87_row0_col4, #T_9ba87_row1_col2, #T_9ba87_row1_col3, #T_9ba87_row1_col4, #T_9ba87_row2_col2, #T_9ba87_row2_col3, #T_9ba87_row2_col4, #T_9ba87_row3_col2, #T_9ba87_row3_col3, #T_9ba87_row3_col4 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_9ba87_row4_col0, #T_9ba87_row6_col1, #T_9ba87_row7_col1 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_9ba87_row4_col1 {\n",
|
||
" background-color: #40bd72;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_9ba87_row4_col2 {\n",
|
||
" background-color: #d9586a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_9ba87_row4_col3, #T_9ba87_row6_col2 {\n",
|
||
" background-color: #a51f99;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_9ba87_row4_col4 {\n",
|
||
" background-color: #ae2892;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_9ba87_row5_col0 {\n",
|
||
" background-color: #4ac16d;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_9ba87_row5_col1 {\n",
|
||
" background-color: #5cc863;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_9ba87_row5_col2 {\n",
|
||
" background-color: #d14e72;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_9ba87_row5_col3 {\n",
|
||
" background-color: #ba3388;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_9ba87_row5_col4 {\n",
|
||
" background-color: #bb3488;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_9ba87_row6_col0, #T_9ba87_row7_col0 {\n",
|
||
" background-color: #1e9d89;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_9ba87_row6_col3, #T_9ba87_row6_col4, #T_9ba87_row7_col2, #T_9ba87_row7_col3, #T_9ba87_row7_col4 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_9ba87\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_9ba87_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_9ba87_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
||
" <th id=\"T_9ba87_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
||
" <th id=\"T_9ba87_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
||
" <th id=\"T_9ba87_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_9ba87_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
|
||
" <td id=\"T_9ba87_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_9ba87_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
|
||
" <td id=\"T_9ba87_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_9ba87_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
||
" <td id=\"T_9ba87_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_9ba87_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
|
||
" <td id=\"T_9ba87_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_9ba87_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_9ba87_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
|
||
" <td id=\"T_9ba87_row4_col0\" class=\"data row4 col0\" >0.866303</td>\n",
|
||
" <td id=\"T_9ba87_row4_col1\" class=\"data row4 col1\" >0.592791</td>\n",
|
||
" <td id=\"T_9ba87_row4_col2\" class=\"data row4 col2\" >0.995675</td>\n",
|
||
" <td id=\"T_9ba87_row4_col3\" class=\"data row4 col3\" >0.528180</td>\n",
|
||
" <td id=\"T_9ba87_row4_col4\" class=\"data row4 col4\" >0.599051</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_9ba87_level0_row5\" class=\"row_heading level0 row5\" >knn</th>\n",
|
||
" <td id=\"T_9ba87_row5_col0\" class=\"data row5 col0\" >0.951728</td>\n",
|
||
" <td id=\"T_9ba87_row5_col1\" class=\"data row5 col1\" >0.720077</td>\n",
|
||
" <td id=\"T_9ba87_row5_col2\" class=\"data row5 col2\" >0.953405</td>\n",
|
||
" <td id=\"T_9ba87_row5_col3\" class=\"data row5 col3\" >0.694141</td>\n",
|
||
" <td id=\"T_9ba87_row5_col4\" class=\"data row5 col4\" >0.701100</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_9ba87_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
|
||
" <td id=\"T_9ba87_row6_col0\" class=\"data row6 col0\" >0.902686</td>\n",
|
||
" <td id=\"T_9ba87_row6_col1\" class=\"data row6 col1\" >0.000000</td>\n",
|
||
" <td id=\"T_9ba87_row6_col2\" class=\"data row6 col2\" >0.766341</td>\n",
|
||
" <td id=\"T_9ba87_row6_col3\" class=\"data row6 col3\" >0.000000</td>\n",
|
||
" <td id=\"T_9ba87_row6_col4\" class=\"data row6 col4\" >0.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_9ba87_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
|
||
" <td id=\"T_9ba87_row7_col0\" class=\"data row7 col0\" >0.902686</td>\n",
|
||
" <td id=\"T_9ba87_row7_col1\" class=\"data row7 col1\" >0.000000</td>\n",
|
||
" <td id=\"T_9ba87_row7_col2\" class=\"data row7 col2\" >0.500000</td>\n",
|
||
" <td id=\"T_9ba87_row7_col3\" class=\"data row7 col3\" >0.000000</td>\n",
|
||
" <td id=\"T_9ba87_row7_col4\" class=\"data row7 col4\" >0.000000</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1dd76e7ec00>"
|
||
]
|
||
},
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
||
" [\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" ]\n",
|
||
"]\n",
|
||
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" ],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Схожий вывод можно сделать и для следующих метрик: Accuracy, F1, ROC AUC, Cohen's Kappa и MCC. Все модели, кроме Naive Bayes и MLP, демонстрируют хорошо развитую способность к разделению классов."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'logistic'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
|
||
"\n",
|
||
"display(best_model)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Вывод данных с ошибкой предсказания для оценки"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'Error items count: 0'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>Predicted</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
"Empty DataFrame\n",
|
||
"Columns: [id, Predicted, name, est_diameter_min, est_diameter_max, relative_velocity, miss_distance, orbiting_body, sentry_object, absolute_magnitude, hazardous]\n",
|
||
"Index: []"
|
||
]
|
||
},
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"preprocessing_result = pipeline_end.transform(X_test)\n",
|
||
"preprocessed_df = pd.DataFrame(\n",
|
||
" preprocessing_result,\n",
|
||
" columns=pipeline_end.get_feature_names_out(),\n",
|
||
")\n",
|
||
"\n",
|
||
"y_pred = class_models[best_model][\"preds\"]\n",
|
||
"\n",
|
||
"error_index = y_test[y_test[\"hazardous\"] != y_pred].index.tolist()\n",
|
||
"display(f\"Error items count: {len(error_index)}\")\n",
|
||
"\n",
|
||
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
|
||
"error_df = X_test.loc[error_index].copy()\n",
|
||
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
|
||
"error_df.sort_index()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Пример использования обученной модели (конвейера) для предсказания\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 37,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>67305</th>\n",
|
||
" <td>3774018</td>\n",
|
||
" <td>(2017 HF1)</td>\n",
|
||
" <td>0.084053</td>\n",
|
||
" <td>0.187949</td>\n",
|
||
" <td>53291.016226</td>\n",
|
||
" <td>68625911.198806</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.5</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"67305 3774018 (2017 HF1) 0.084053 0.187949 \n",
|
||
"\n",
|
||
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
|
||
"67305 53291.016226 68625911.198806 Earth False \n",
|
||
"\n",
|
||
" absolute_magnitude hazardous \n",
|
||
"67305 22.5 False "
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" <th>id</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>67305</th>\n",
|
||
" <td>-0.140818</td>\n",
|
||
" <td>-0.140818</td>\n",
|
||
" <td>0.207258</td>\n",
|
||
" <td>1.410653</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>-0.353797</td>\n",
|
||
" <td>-0.328347</td>\n",
|
||
" <td>3774018.0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" est_diameter_min est_diameter_max relative_velocity miss_distance \\\n",
|
||
"67305 -0.140818 -0.140818 0.207258 1.410653 \n",
|
||
"\n",
|
||
" sentry_object absolute_magnitude hazardous id \n",
|
||
"67305 0.0 -0.353797 -0.328347 3774018.0 "
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'predicted: False (proba: [9.99855425e-01 1.44575476e-04])'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'real: 0'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"model = class_models[best_model][\"pipeline\"]\n",
|
||
"\n",
|
||
"example_id = 67305\n",
|
||
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
|
||
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
|
||
"display(test)\n",
|
||
"display(test_preprocessed)\n",
|
||
"result_proba = model.predict_proba(test)[0]\n",
|
||
"result = model.predict(test)[0]\n",
|
||
"real = int(y_test.loc[example_id].values[0])\n",
|
||
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
|
||
"display(f\"real: {real}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Подбор гиперпараметров методом поиска по сетке "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 38,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
||
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'model__criterion': 'gini',\n",
|
||
" 'model__max_depth': 5,\n",
|
||
" 'model__max_features': 'sqrt',\n",
|
||
" 'model__n_estimators': 50}"
|
||
]
|
||
},
|
||
"execution_count": 38,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.model_selection import GridSearchCV\n",
|
||
"\n",
|
||
"optimized_model_type = \"random_forest\"\n",
|
||
"\n",
|
||
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
|
||
"\n",
|
||
"param_grid = {\n",
|
||
" \"model__n_estimators\": [10, 50, 100],\n",
|
||
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
|
||
" \"model__max_depth\": [5, 7, 10],\n",
|
||
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
||
"}\n",
|
||
"\n",
|
||
"gs_optomizer = GridSearchCV(\n",
|
||
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
|
||
")\n",
|
||
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
|
||
"gs_optomizer.best_params_"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Обучение модели с новыми гиперпараметрами"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 48,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"from sklearn.compose import ColumnTransformer\n",
|
||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"\n",
|
||
"# Определяем числовые признаки\n",
|
||
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
|
||
"\n",
|
||
"# Установка random_state\n",
|
||
"random_state = 42\n",
|
||
"\n",
|
||
"# Определение трансформера\n",
|
||
"pipeline_end = ColumnTransformer([\n",
|
||
" ('numeric', StandardScaler(), numeric_features),\n",
|
||
" # Добавьте другие трансформеры, если требуется\n",
|
||
"])\n",
|
||
"\n",
|
||
"# Объявление модели\n",
|
||
"optimized_model = RandomForestClassifier(\n",
|
||
" random_state=random_state,\n",
|
||
" criterion=\"gini\",\n",
|
||
" max_depth=5,\n",
|
||
" max_features=\"sqrt\",\n",
|
||
" n_estimators=50,\n",
|
||
")\n",
|
||
"\n",
|
||
"# Создание пайплайна с корректными шагами\n",
|
||
"result = {}\n",
|
||
"\n",
|
||
"# Обучение модели\n",
|
||
"result[\"pipeline\"] = Pipeline([\n",
|
||
" (\"pipeline\", pipeline_end),\n",
|
||
" (\"model\", optimized_model)\n",
|
||
"]).fit(X_train, y_train.values.ravel())\n",
|
||
"\n",
|
||
"# Прогнозирование и расчет метрик\n",
|
||
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
|
||
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
|
||
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
|
||
"\n",
|
||
"# Метрики для оценки модели\n",
|
||
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
|
||
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
|
||
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
|
||
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
|
||
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
|
||
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
|
||
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
|
||
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
|
||
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Формирование данных для оценки старой и новой версии модели"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 49,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
|
||
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
||
" data=class_models[optimized_model_type]\n",
|
||
")\n",
|
||
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
||
" data=result\n",
|
||
")\n",
|
||
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
|
||
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Оценка параметров старой и новой модели"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 50,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_103c9_row0_col0, #T_103c9_row0_col1, #T_103c9_row0_col2, #T_103c9_row0_col3 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_103c9_row0_col4, #T_103c9_row0_col5, #T_103c9_row0_col6, #T_103c9_row0_col7 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_103c9_row1_col0, #T_103c9_row1_col1, #T_103c9_row1_col2, #T_103c9_row1_col3 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_103c9_row1_col4, #T_103c9_row1_col5, #T_103c9_row1_col6, #T_103c9_row1_col7 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_103c9\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_103c9_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
||
" <th id=\"T_103c9_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
||
" <th id=\"T_103c9_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
||
" <th id=\"T_103c9_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
||
" <th id=\"T_103c9_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
||
" <th id=\"T_103c9_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_103c9_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
||
" <th id=\"T_103c9_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th class=\"index_name level0\" >Name</th>\n",
|
||
" <th class=\"blank col0\" > </th>\n",
|
||
" <th class=\"blank col1\" > </th>\n",
|
||
" <th class=\"blank col2\" > </th>\n",
|
||
" <th class=\"blank col3\" > </th>\n",
|
||
" <th class=\"blank col4\" > </th>\n",
|
||
" <th class=\"blank col5\" > </th>\n",
|
||
" <th class=\"blank col6\" > </th>\n",
|
||
" <th class=\"blank col7\" > </th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_103c9_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
||
" <td id=\"T_103c9_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_103c9_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_103c9_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_103c9_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_103c9_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" <td id=\"T_103c9_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
||
" <td id=\"T_103c9_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
||
" <td id=\"T_103c9_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_103c9_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
||
" <td id=\"T_103c9_row1_col0\" class=\"data row1 col0\" >0.833191</td>\n",
|
||
" <td id=\"T_103c9_row1_col1\" class=\"data row1 col1\" >0.862500</td>\n",
|
||
" <td id=\"T_103c9_row1_col2\" class=\"data row1 col2\" >0.138433</td>\n",
|
||
" <td id=\"T_103c9_row1_col3\" class=\"data row1 col3\" >0.156109</td>\n",
|
||
" <td id=\"T_103c9_row1_col4\" class=\"data row1 col4\" >0.913456</td>\n",
|
||
" <td id=\"T_103c9_row1_col5\" class=\"data row1 col5\" >0.915456</td>\n",
|
||
" <td id=\"T_103c9_row1_col6\" class=\"data row1 col6\" >0.237420</td>\n",
|
||
" <td id=\"T_103c9_row1_col7\" class=\"data row1 col7\" >0.264368</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1dd76b55010>"
|
||
]
|
||
},
|
||
"execution_count": 50,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"optimized_metrics[\n",
|
||
" [\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" \"Accuracy_train\",\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_train\",\n",
|
||
" \"F1_test\",\n",
|
||
" ]\n",
|
||
"].style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Precision_train\",\n",
|
||
" \"Precision_test\",\n",
|
||
" \"Recall_train\",\n",
|
||
" \"Recall_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 51,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<style type=\"text/css\">\n",
|
||
"#T_6af3a_row0_col0, #T_6af3a_row0_col1 {\n",
|
||
" background-color: #a8db34;\n",
|
||
" color: #000000;\n",
|
||
"}\n",
|
||
"#T_6af3a_row0_col2, #T_6af3a_row0_col3, #T_6af3a_row0_col4 {\n",
|
||
" background-color: #da5a6a;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_6af3a_row1_col0, #T_6af3a_row1_col1 {\n",
|
||
" background-color: #26818e;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"#T_6af3a_row1_col2, #T_6af3a_row1_col3, #T_6af3a_row1_col4 {\n",
|
||
" background-color: #4e02a2;\n",
|
||
" color: #f1f1f1;\n",
|
||
"}\n",
|
||
"</style>\n",
|
||
"<table id=\"T_6af3a\">\n",
|
||
" <thead>\n",
|
||
" <tr>\n",
|
||
" <th class=\"blank level0\" > </th>\n",
|
||
" <th id=\"T_6af3a_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
||
" <th id=\"T_6af3a_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
||
" <th id=\"T_6af3a_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
||
" <th id=\"T_6af3a_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
||
" <th id=\"T_6af3a_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th class=\"index_name level0\" >Name</th>\n",
|
||
" <th class=\"blank col0\" > </th>\n",
|
||
" <th class=\"blank col1\" > </th>\n",
|
||
" <th class=\"blank col2\" > </th>\n",
|
||
" <th class=\"blank col3\" > </th>\n",
|
||
" <th class=\"blank col4\" > </th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_6af3a_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
||
" <td id=\"T_6af3a_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
||
" <td id=\"T_6af3a_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
||
" <td id=\"T_6af3a_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
||
" <td id=\"T_6af3a_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
||
" <td id=\"T_6af3a_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th id=\"T_6af3a_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
||
" <td id=\"T_6af3a_row1_col0\" class=\"data row1 col0\" >0.915456</td>\n",
|
||
" <td id=\"T_6af3a_row1_col1\" class=\"data row1 col1\" >0.264368</td>\n",
|
||
" <td id=\"T_6af3a_row1_col2\" class=\"data row1 col2\" >0.927493</td>\n",
|
||
" <td id=\"T_6af3a_row1_col3\" class=\"data row1 col3\" >0.241751</td>\n",
|
||
" <td id=\"T_6af3a_row1_col4\" class=\"data row1 col4\" >0.345694</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n"
|
||
],
|
||
"text/plain": [
|
||
"<pandas.io.formats.style.Styler at 0x1dd76b54c50>"
|
||
]
|
||
},
|
||
"execution_count": 51,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"optimized_metrics[\n",
|
||
" [\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" ]\n",
|
||
"].style.background_gradient(\n",
|
||
" cmap=\"plasma\",\n",
|
||
" low=0.3,\n",
|
||
" high=1,\n",
|
||
" subset=[\n",
|
||
" \"ROC_AUC_test\",\n",
|
||
" \"MCC_test\",\n",
|
||
" \"Cohen_kappa_test\",\n",
|
||
" ],\n",
|
||
").background_gradient(\n",
|
||
" cmap=\"viridis\",\n",
|
||
" low=1,\n",
|
||
" high=0.3,\n",
|
||
" subset=[\n",
|
||
" \"Accuracy_test\",\n",
|
||
" \"F1_test\",\n",
|
||
" ],\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 53,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 1000x400 with 4 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"# Plot the confusion matrices stored in optimized_metrics (one panel per row).\n",
|
||
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
|
||
"\n",
|
||
"for index in range(0, len(optimized_metrics)):\n",
|
||
"    c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
|
||
"    # confusion_matrix() sorts class labels ascending when `labels` is not given,\n",
|
||
"    # so row/column 0 is hazardous == False (safe) and row/column 1 is hazardous == True.\n",
|
||
"    ConfusionMatrixDisplay(\n",
|
||
"        confusion_matrix=c_matrix, display_labels=[\"safe\", \"hazardous\"]\n",
|
||
"    ).plot(ax=ax.flat[index])\n",
|
||
"    # Title each panel with its row name from the frame index (\"Old\" / \"New\").\n",
|
||
"    ax.flat[index].set_title(str(optimized_metrics.index[index]))\n",
|
||
"\n",
|
||
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"В желтом квадрате мы наблюдаем значение 16400, что обозначает количество правильно классифицированных объектов класса 0 (hazardous = False, то есть безопасных объектов). Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
|
||
"\n",
|
||
"В фиолетовом квадрате значение 276 указывает на количество правильно классифицированных объектов класса 1 (hazardous = True, то есть опасных объектов). Это является показателем невысокой полноты модели для данного класса (Recall_test ≈ 0.16): большая часть опасных объектов остается нераспознанной."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Index(['id', 'name', 'est_diameter_min', 'est_diameter_max',\n",
|
||
" 'relative_velocity', 'miss_distance', 'orbiting_body', 'sentry_object',\n",
|
||
" 'absolute_magnitude', 'hazardous'],\n",
|
||
" dtype='object')\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>2162635</td>\n",
|
||
" <td>162635 (2000 SS164)</td>\n",
|
||
" <td>1.198271</td>\n",
|
||
" <td>2.679415</td>\n",
|
||
" <td>13569.249224</td>\n",
|
||
" <td>5.483974e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>16.73</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>2277475</td>\n",
|
||
" <td>277475 (2005 WK4)</td>\n",
|
||
" <td>0.265800</td>\n",
|
||
" <td>0.594347</td>\n",
|
||
" <td>73588.726663</td>\n",
|
||
" <td>6.143813e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.00</td>\n",
|
||
" <td>True</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>2512244</td>\n",
|
||
" <td>512244 (2015 YE18)</td>\n",
|
||
" <td>0.722030</td>\n",
|
||
" <td>1.614507</td>\n",
|
||
" <td>114258.692129</td>\n",
|
||
" <td>4.979872e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>17.83</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>3596030</td>\n",
|
||
" <td>(2012 BV13)</td>\n",
|
||
" <td>0.096506</td>\n",
|
||
" <td>0.215794</td>\n",
|
||
" <td>24764.303138</td>\n",
|
||
" <td>2.543497e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>3667127</td>\n",
|
||
" <td>(2014 GE35)</td>\n",
|
||
" <td>0.255009</td>\n",
|
||
" <td>0.570217</td>\n",
|
||
" <td>42737.733765</td>\n",
|
||
" <td>4.627557e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.09</td>\n",
|
||
" <td>True</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90831</th>\n",
|
||
" <td>3763337</td>\n",
|
||
" <td>(2016 VX1)</td>\n",
|
||
" <td>0.026580</td>\n",
|
||
" <td>0.059435</td>\n",
|
||
" <td>52078.886692</td>\n",
|
||
" <td>1.230039e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.00</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90832</th>\n",
|
||
" <td>3837603</td>\n",
|
||
" <td>(2019 AD3)</td>\n",
|
||
" <td>0.016771</td>\n",
|
||
" <td>0.037501</td>\n",
|
||
" <td>46114.605073</td>\n",
|
||
" <td>5.432121e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>26.00</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90833</th>\n",
|
||
" <td>54017201</td>\n",
|
||
" <td>(2020 JP3)</td>\n",
|
||
" <td>0.031956</td>\n",
|
||
" <td>0.071456</td>\n",
|
||
" <td>7566.807732</td>\n",
|
||
" <td>2.840077e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.60</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90834</th>\n",
|
||
" <td>54115824</td>\n",
|
||
" <td>(2021 CN5)</td>\n",
|
||
" <td>0.007321</td>\n",
|
||
" <td>0.016370</td>\n",
|
||
" <td>69199.154484</td>\n",
|
||
" <td>6.869206e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.80</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>90835</th>\n",
|
||
" <td>54205447</td>\n",
|
||
" <td>(2021 TW7)</td>\n",
|
||
" <td>0.039862</td>\n",
|
||
" <td>0.089133</td>\n",
|
||
" <td>27024.455553</td>\n",
|
||
" <td>5.977213e+07</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.12</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>90836 rows × 10 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"0 2162635 162635 (2000 SS164) 1.198271 2.679415 \n",
|
||
"1 2277475 277475 (2005 WK4) 0.265800 0.594347 \n",
|
||
"2 2512244 512244 (2015 YE18) 0.722030 1.614507 \n",
|
||
"3 3596030 (2012 BV13) 0.096506 0.215794 \n",
|
||
"4 3667127 (2014 GE35) 0.255009 0.570217 \n",
|
||
"... ... ... ... ... \n",
|
||
"90831 3763337 (2016 VX1) 0.026580 0.059435 \n",
|
||
"90832 3837603 (2019 AD3) 0.016771 0.037501 \n",
|
||
"90833 54017201 (2020 JP3) 0.031956 0.071456 \n",
|
||
"90834 54115824 (2021 CN5) 0.007321 0.016370 \n",
|
||
"90835 54205447 (2021 TW7) 0.039862 0.089133 \n",
|
||
"\n",
|
||
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
|
||
"0 13569.249224 5.483974e+07 Earth False \n",
|
||
"1 73588.726663 6.143813e+07 Earth False \n",
|
||
"2 114258.692129 4.979872e+07 Earth False \n",
|
||
"3 24764.303138 2.543497e+07 Earth False \n",
|
||
"4 42737.733765 4.627557e+07 Earth False \n",
|
||
"... ... ... ... ... \n",
|
||
"90831 52078.886692 1.230039e+07 Earth False \n",
|
||
"90832 46114.605073 5.432121e+07 Earth False \n",
|
||
"90833 7566.807732 2.840077e+07 Earth False \n",
|
||
"90834 69199.154484 6.869206e+07 Earth False \n",
|
||
"90835 27024.455553 5.977213e+07 Earth False \n",
|
||
"\n",
|
||
" absolute_magnitude hazardous \n",
|
||
"0 16.73 False \n",
|
||
"1 20.00 True \n",
|
||
"2 17.83 False \n",
|
||
"3 22.20 False \n",
|
||
"4 20.09 True \n",
|
||
"... ... ... \n",
|
||
"90831 25.00 False \n",
|
||
"90832 26.00 False \n",
|
||
"90833 24.60 False \n",
|
||
"90834 27.80 False \n",
|
||
"90835 24.12 False \n",
|
||
"\n",
|
||
"[90836 rows x 10 columns]"
|
||
]
|
||
},
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import seaborn as sns\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"from sklearn import set_config\n",
|
||
"\n",
|
||
"random_state=42\n",
|
||
"set_config(transform_output=\"pandas\")\n",
|
||
"df = pd.read_csv(\".//static//csv//neo.csv\")\n",
|
||
"print(df.columns)\n",
|
||
"df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Разделение набора данных на обучающую и тестовую выборки (80/20) для задачи регрессии"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>35538</th>\n",
|
||
" <td>3826685</td>\n",
|
||
" <td>(2018 PR10)</td>\n",
|
||
" <td>0.038420</td>\n",
|
||
" <td>0.085909</td>\n",
|
||
" <td>91103.489666</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>40393</th>\n",
|
||
" <td>2277830</td>\n",
|
||
" <td>277830 (2006 HR29)</td>\n",
|
||
" <td>0.192555</td>\n",
|
||
" <td>0.430566</td>\n",
|
||
" <td>28359.611312</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.70</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>58540</th>\n",
|
||
" <td>3638201</td>\n",
|
||
" <td>(2013 HT25)</td>\n",
|
||
" <td>0.004619</td>\n",
|
||
" <td>0.010329</td>\n",
|
||
" <td>107351.426865</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>28.80</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61670</th>\n",
|
||
" <td>3836282</td>\n",
|
||
" <td>(2018 WR)</td>\n",
|
||
" <td>0.015295</td>\n",
|
||
" <td>0.034201</td>\n",
|
||
" <td>21423.536884</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>26.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>11435</th>\n",
|
||
" <td>3802002</td>\n",
|
||
" <td>(2018 FU1)</td>\n",
|
||
" <td>0.011603</td>\n",
|
||
" <td>0.025944</td>\n",
|
||
" <td>69856.053840</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>26.80</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>6265</th>\n",
|
||
" <td>2530151</td>\n",
|
||
" <td>530151 (2011 AW55)</td>\n",
|
||
" <td>0.211132</td>\n",
|
||
" <td>0.472106</td>\n",
|
||
" <td>88209.754856</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.50</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>54886</th>\n",
|
||
" <td>3831736</td>\n",
|
||
" <td>(2018 TD5)</td>\n",
|
||
" <td>0.035039</td>\n",
|
||
" <td>0.078350</td>\n",
|
||
" <td>58758.452153</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>24.40</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>76820</th>\n",
|
||
" <td>2512234</td>\n",
|
||
" <td>512234 (2015 VO66)</td>\n",
|
||
" <td>0.211132</td>\n",
|
||
" <td>0.472106</td>\n",
|
||
" <td>52355.509176</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.50</td>\n",
|
||
" <td>True</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>860</th>\n",
|
||
" <td>54054466</td>\n",
|
||
" <td>(2020 SG1)</td>\n",
|
||
" <td>0.282199</td>\n",
|
||
" <td>0.631015</td>\n",
|
||
" <td>50527.379563</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>19.87</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>15795</th>\n",
|
||
" <td>3773929</td>\n",
|
||
" <td>(2017 GL7)</td>\n",
|
||
" <td>0.075258</td>\n",
|
||
" <td>0.168283</td>\n",
|
||
" <td>22527.647871</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>22.74</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>72668 rows × 9 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"35538 3826685 (2018 PR10) 0.038420 0.085909 \n",
|
||
"40393 2277830 277830 (2006 HR29) 0.192555 0.430566 \n",
|
||
"58540 3638201 (2013 HT25) 0.004619 0.010329 \n",
|
||
"61670 3836282 (2018 WR) 0.015295 0.034201 \n",
|
||
"11435 3802002 (2018 FU1) 0.011603 0.025944 \n",
|
||
"... ... ... ... ... \n",
|
||
"6265 2530151 530151 (2011 AW55) 0.211132 0.472106 \n",
|
||
"54886 3831736 (2018 TD5) 0.035039 0.078350 \n",
|
||
"76820 2512234 512234 (2015 VO66) 0.211132 0.472106 \n",
|
||
"860 54054466 (2020 SG1) 0.282199 0.631015 \n",
|
||
"15795 3773929 (2017 GL7) 0.075258 0.168283 \n",
|
||
"\n",
|
||
" relative_velocity orbiting_body sentry_object absolute_magnitude \\\n",
|
||
"35538 91103.489666 Earth False 24.20 \n",
|
||
"40393 28359.611312 Earth False 20.70 \n",
|
||
"58540 107351.426865 Earth False 28.80 \n",
|
||
"61670 21423.536884 Earth False 26.20 \n",
|
||
"11435 69856.053840 Earth False 26.80 \n",
|
||
"... ... ... ... ... \n",
|
||
"6265 88209.754856 Earth False 20.50 \n",
|
||
"54886 58758.452153 Earth False 24.40 \n",
|
||
"76820 52355.509176 Earth False 20.50 \n",
|
||
"860 50527.379563 Earth False 19.87 \n",
|
||
"15795 22527.647871 Earth False 22.74 \n",
|
||
"\n",
|
||
" hazardous \n",
|
||
"35538 False \n",
|
||
"40393 False \n",
|
||
"58540 False \n",
|
||
"61670 False \n",
|
||
"11435 False \n",
|
||
"... ... \n",
|
||
"6265 False \n",
|
||
"54886 False \n",
|
||
"76820 True \n",
|
||
"860 False \n",
|
||
"15795 False \n",
|
||
"\n",
|
||
"[72668 rows x 9 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>35538</th>\n",
|
||
" <td>6.350550e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>40393</th>\n",
|
||
" <td>2.868167e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>58540</th>\n",
|
||
" <td>5.388098e+04</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>61670</th>\n",
|
||
" <td>5.103884e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>11435</th>\n",
|
||
" <td>7.360836e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>6265</th>\n",
|
||
" <td>4.034289e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>54886</th>\n",
|
||
" <td>4.389994e+06</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>76820</th>\n",
|
||
" <td>4.380532e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>860</th>\n",
|
||
" <td>5.837007e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>15795</th>\n",
|
||
" <td>2.281469e+07</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>72668 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" miss_distance\n",
|
||
"35538 6.350550e+07\n",
|
||
"40393 2.868167e+07\n",
|
||
"58540 5.388098e+04\n",
|
||
"61670 5.103884e+07\n",
|
||
"11435 7.360836e+07\n",
|
||
"... ...\n",
|
||
"6265 4.034289e+07\n",
|
||
"54886 4.389994e+06\n",
|
||
"76820 4.380532e+07\n",
|
||
"860 5.837007e+07\n",
|
||
"15795 2.281469e+07\n",
|
||
"\n",
|
||
"[72668 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>id</th>\n",
|
||
" <th>name</th>\n",
|
||
" <th>est_diameter_min</th>\n",
|
||
" <th>est_diameter_max</th>\n",
|
||
" <th>relative_velocity</th>\n",
|
||
" <th>orbiting_body</th>\n",
|
||
" <th>sentry_object</th>\n",
|
||
" <th>absolute_magnitude</th>\n",
|
||
" <th>hazardous</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>20406</th>\n",
|
||
" <td>3943344</td>\n",
|
||
" <td>(2019 YT1)</td>\n",
|
||
" <td>0.024241</td>\n",
|
||
" <td>0.054205</td>\n",
|
||
" <td>22148.962596</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.20</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>74443</th>\n",
|
||
" <td>3879239</td>\n",
|
||
" <td>(2019 US)</td>\n",
|
||
" <td>0.012722</td>\n",
|
||
" <td>0.028447</td>\n",
|
||
" <td>26477.211836</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>26.60</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>74306</th>\n",
|
||
" <td>3879244</td>\n",
|
||
" <td>(2019 UU)</td>\n",
|
||
" <td>0.013322</td>\n",
|
||
" <td>0.029788</td>\n",
|
||
" <td>33770.201397</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>26.50</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>45943</th>\n",
|
||
" <td>2481965</td>\n",
|
||
" <td>481965 (2009 EB1)</td>\n",
|
||
" <td>0.193444</td>\n",
|
||
" <td>0.432554</td>\n",
|
||
" <td>43599.575296</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.69</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>62859</th>\n",
|
||
" <td>3789471</td>\n",
|
||
" <td>(2017 WJ1)</td>\n",
|
||
" <td>0.044112</td>\n",
|
||
" <td>0.098637</td>\n",
|
||
" <td>36398.080883</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>23.90</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>51634</th>\n",
|
||
" <td>3694131</td>\n",
|
||
" <td>(2014 UF56)</td>\n",
|
||
" <td>0.008801</td>\n",
|
||
" <td>0.019681</td>\n",
|
||
" <td>57414.305699</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.40</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>85083</th>\n",
|
||
" <td>54235475</td>\n",
|
||
" <td>(2022 AG1)</td>\n",
|
||
" <td>0.024920</td>\n",
|
||
" <td>0.055724</td>\n",
|
||
" <td>50882.935767</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.14</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>38905</th>\n",
|
||
" <td>3775176</td>\n",
|
||
" <td>(2017 LD)</td>\n",
|
||
" <td>0.008405</td>\n",
|
||
" <td>0.018795</td>\n",
|
||
" <td>24954.754212</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>27.50</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>16144</th>\n",
|
||
" <td>2434734</td>\n",
|
||
" <td>434734 (2006 FX)</td>\n",
|
||
" <td>0.265800</td>\n",
|
||
" <td>0.594347</td>\n",
|
||
" <td>57455.404666</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>20.00</td>\n",
|
||
" <td>True</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>54508</th>\n",
|
||
" <td>3170208</td>\n",
|
||
" <td>(2003 YG136)</td>\n",
|
||
" <td>0.023150</td>\n",
|
||
" <td>0.051765</td>\n",
|
||
" <td>72602.093427</td>\n",
|
||
" <td>Earth</td>\n",
|
||
" <td>False</td>\n",
|
||
" <td>25.30</td>\n",
|
||
" <td>False</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>18168 rows × 9 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" id name est_diameter_min est_diameter_max \\\n",
|
||
"20406 3943344 (2019 YT1) 0.024241 0.054205 \n",
|
||
"74443 3879239 (2019 US) 0.012722 0.028447 \n",
|
||
"74306 3879244 (2019 UU) 0.013322 0.029788 \n",
|
||
"45943 2481965 481965 (2009 EB1) 0.193444 0.432554 \n",
|
||
"62859 3789471 (2017 WJ1) 0.044112 0.098637 \n",
|
||
"... ... ... ... ... \n",
|
||
"51634 3694131 (2014 UF56) 0.008801 0.019681 \n",
|
||
"85083 54235475 (2022 AG1) 0.024920 0.055724 \n",
|
||
"38905 3775176 (2017 LD) 0.008405 0.018795 \n",
|
||
"16144 2434734 434734 (2006 FX) 0.265800 0.594347 \n",
|
||
"54508 3170208 (2003 YG136) 0.023150 0.051765 \n",
|
||
"\n",
|
||
" relative_velocity orbiting_body sentry_object absolute_magnitude \\\n",
|
||
"20406 22148.962596 Earth False 25.20 \n",
|
||
"74443 26477.211836 Earth False 26.60 \n",
|
||
"74306 33770.201397 Earth False 26.50 \n",
|
||
"45943 43599.575296 Earth False 20.69 \n",
|
||
"62859 36398.080883 Earth False 23.90 \n",
|
||
"... ... ... ... ... \n",
|
||
"51634 57414.305699 Earth False 27.40 \n",
|
||
"85083 50882.935767 Earth False 25.14 \n",
|
||
"38905 24954.754212 Earth False 27.50 \n",
|
||
"16144 57455.404666 Earth False 20.00 \n",
|
||
"54508 72602.093427 Earth False 25.30 \n",
|
||
"\n",
|
||
" hazardous \n",
|
||
"20406 False \n",
|
||
"74443 False \n",
|
||
"74306 False \n",
|
||
"45943 False \n",
|
||
"62859 False \n",
|
||
"... ... \n",
|
||
"51634 False \n",
|
||
"85083 False \n",
|
||
"38905 False \n",
|
||
"16144 True \n",
|
||
"54508 False \n",
|
||
"\n",
|
||
"[18168 rows x 9 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>miss_distance</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>20406</th>\n",
|
||
" <td>5.028574e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>74443</th>\n",
|
||
" <td>1.683201e+06</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>74306</th>\n",
|
||
" <td>3.943220e+06</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>45943</th>\n",
|
||
" <td>7.346837e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>62859</th>\n",
|
||
" <td>6.352916e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>51634</th>\n",
|
||
" <td>1.987273e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>85083</th>\n",
|
||
" <td>3.119646e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>38905</th>\n",
|
||
" <td>1.111942e+07</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>16144</th>\n",
|
||
" <td>8.501684e+06</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>54508</th>\n",
|
||
" <td>4.624727e+07</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>18168 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" miss_distance\n",
|
||
"20406 5.028574e+07\n",
|
||
"74443 1.683201e+06\n",
|
||
"74306 3.943220e+06\n",
|
||
"45943 7.346837e+07\n",
|
||
"62859 6.352916e+07\n",
|
||
"... ...\n",
|
||
"51634 1.987273e+07\n",
|
||
"85083 3.119646e+07\n",
|
||
"38905 1.111942e+07\n",
|
||
"16144 8.501684e+06\n",
|
||
"54508 4.624727e+07\n",
|
||
"\n",
|
||
"[18168 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from typing import Tuple\n",
|
||
"import pandas as pd\n",
|
||
"from pandas import DataFrame\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"\n",
|
||
"def split_into_train_test(\n",
|
||
" df_input: DataFrame,\n",
|
||
" target_colname: str = \"miss_distance\",\n",
|
||
" frac_train: float = 0.8,\n",
|
||
" random_state: int = None,\n",
|
||
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
||
" \n",
|
||
" if not (0 < frac_train < 1):\n",
|
||
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
|
||
" \n",
|
||
" # Проверка наличия целевого признака\n",
|
||
" if target_colname not in df_input.columns:\n",
|
||
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
|
||
" \n",
|
||
" # Разделяем данные на признаки и целевую переменную\n",
|
||
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
|
||
" y = df_input[[target_colname]] # Целевая переменная\n",
|
||
"\n",
|
||
" # Разделяем данные на обучающую и тестовую выборки\n",
|
||
" X_train, X_test, y_train, y_test = train_test_split(\n",
|
||
" X, y,\n",
|
||
" test_size=(1.0 - frac_train),\n",
|
||
" random_state=random_state\n",
|
||
" )\n",
|
||
" \n",
|
||
" return X_train, X_test, y_train, y_test\n",
|
||
"\n",
|
||
"# Применение функции для разделения данных\n",
|
||
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
|
||
" df, \n",
|
||
" target_colname=\"miss_distance\", \n",
|
||
" frac_train=0.8, \n",
|
||
" random_state=42\n",
|
||
")\n",
|
||
"\n",
|
||
"# Для отображения результатов\n",
|
||
"display(\"X_train\", X_train)\n",
|
||
"display(\"y_train\", y_train)\n",
|
||
"\n",
|
||
"display(\"X_test\", X_test)\n",
|
||
"display(\"y_test\", y_test)\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Формирование конвейера для решения задачи регрессии"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from sklearn.base import BaseEstimator, TransformerMixin\n",
|
||
"from sklearn.compose import ColumnTransformer\n",
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"from sklearn.impute import SimpleImputer\n",
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"from sklearn.preprocessing import OneHotEncoder\n",
|
||
"from sklearn.ensemble import RandomForestRegressor # Пример регрессионной модели\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"from sklearn.pipeline import make_pipeline\n",
|
||
"\n",
|
||
"class EarthObjectsFeatures(BaseEstimator, TransformerMixin):\n",
|
||
" def __init__(self):\n",
|
||
" pass\n",
|
||
" \n",
|
||
" def fit(self, X, y=None):\n",
|
||
" return self\n",
|
||
"\n",
|
||
" def transform(self, X, y=None):\n",
|
||
" X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n",
|
||
" return X\n",
|
||
"\n",
|
||
" def get_feature_names_out(self, features_in):\n",
|
||
" return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
|
||
"\n",
|
||
"# Указываем столбцы, которые нужно удалить и обрабатывать\n",
|
||
"columns_to_drop = [\"name\", \"orbiting_body\"]\n",
|
||
"num_columns = [\"est_diameter_min\", \"est_diameter_max\",\n",
|
||
" \"relative_velocity\", \"sentry_object\",\n",
|
||
" \"absolute_magnitude\", \"hazardous\"]\n",
|
||
"cat_columns = [] \n",
|
||
"\n",
|
||
"# Определяем предобработку для численных данных\n",
|
||
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
||
"num_scaler = StandardScaler()\n",
|
||
"preprocessing_num = Pipeline(\n",
|
||
" [\n",
|
||
" (\"imputer\", num_imputer),\n",
|
||
" (\"scaler\", num_scaler),\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"# Определяем предобработку для категориальных данных\n",
|
||
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
||
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
||
"preprocessing_cat = Pipeline(\n",
|
||
" [\n",
|
||
" (\"imputer\", cat_imputer),\n",
|
||
" (\"encoder\", cat_encoder),\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"# Подготовка признаков с использованием ColumnTransformer\n",
|
||
"features_preprocessing = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=False,\n",
|
||
" transformers=[\n",
|
||
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
|
||
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
|
||
" ],\n",
|
||
" remainder=\"passthrough\"\n",
|
||
")\n",
|
||
"\n",
|
||
"# Удаление нежелательных столбцов\n",
|
||
"drop_columns = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=False,\n",
|
||
" transformers=[\n",
|
||
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
||
" ],\n",
|
||
" remainder=\"passthrough\",\n",
|
||
")\n",
|
||
"\n",
|
||
"# Постобработка признаков\n",
|
||
"features_postprocessing = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=False,\n",
|
||
" transformers=[\n",
|
||
" (\"preprocessing_cat\", preprocessing_cat, [\"Cabin_type\"]), \n",
|
||
" ],\n",
|
||
" remainder=\"passthrough\",\n",
|
||
")\n",
|
||
"\n",
|
||
"# Создание окончательного конвейера\n",
|
||
"pipeline = Pipeline(\n",
|
||
" [\n",
|
||
" (\"features_preprocessing\", features_preprocessing),\n",
|
||
" (\"drop_columns\", drop_columns),\n",
|
||
" (\"model\", RandomForestRegressor()) # Выбор модели для обучения\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"# Использование конвейера\n",
|
||
"def train_pipeline(X, y):\n",
|
||
" pipeline.fit(X, y)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### Формирование набора моделей для регрессии"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn.pipeline import make_pipeline\n",
|
||
"from sklearn.preprocessing import PolynomialFeatures\n",
|
||
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
|
||
"\n",
|
||
"random_state = 9\n",
|
||
"\n",
|
||
"models = {\n",
|
||
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
|
||
" \"linear_poly\": {\n",
|
||
" \"model\": make_pipeline(\n",
|
||
" PolynomialFeatures(degree=2),\n",
|
||
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
||
" )\n",
|
||
" },\n",
|
||
" \"linear_interact\": {\n",
|
||
" \"model\": make_pipeline(\n",
|
||
" PolynomialFeatures(interaction_only=True),\n",
|
||
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
|
||
" )\n",
|
||
" },\n",
|
||
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
|
||
" \"decision_tree\": {\n",
|
||
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
|
||
" },\n",
|
||
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
|
||
" \"random_forest\": {\n",
|
||
" \"model\": ensemble.RandomForestRegressor(\n",
|
||
" max_depth=7, random_state=random_state, n_jobs=-1\n",
|
||
" )\n",
|
||
" },\n",
|
||
" \"mlp\": {\n",
|
||
" \"model\": neural_network.MLPRegressor(\n",
|
||
" activation=\"tanh\",\n",
|
||
" hidden_layer_sizes=(3,),\n",
|
||
" max_iter=500,\n",
|
||
" early_stopping=True,\n",
|
||
" random_state=random_state,\n",
|
||
" )\n",
|
||
" },\n",
|
||
"}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: linear\n"
|
||
]
|
||
},
|
||
{
|
||
"ename": "ValueError",
|
||
"evalue": "could not convert string to float: '(2018 PR10)'",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[1;32mIn[9], line 8\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m model_name \u001b[38;5;129;01min\u001b[39;00m models\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 8\u001b[0m fitted_model \u001b[38;5;241m=\u001b[39m \u001b[43mmodels\u001b[49m\u001b[43m[\u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mX_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mravel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 11\u001b[0m y_train_pred \u001b[38;5;241m=\u001b[39m fitted_model\u001b[38;5;241m.\u001b[39mpredict(X_train\u001b[38;5;241m.\u001b[39mvalues)\n\u001b[0;32m 12\u001b[0m y_test_pred \u001b[38;5;241m=\u001b[39m fitted_model\u001b[38;5;241m.\u001b[39mpredict(X_test\u001b[38;5;241m.\u001b[39mvalues)\n",
|
||
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[1;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1466\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[0;32m 1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 1469\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 1470\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 1471\u001b[0m )\n\u001b[0;32m 1472\u001b[0m ):\n\u001b[1;32m-> 1473\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\linear_model\\_base.py:609\u001b[0m, in \u001b[0;36mLinearRegression.fit\u001b[1;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[0;32m 605\u001b[0m n_jobs_ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_jobs\n\u001b[0;32m 607\u001b[0m accept_sparse \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpositive \u001b[38;5;28;01melse\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcsr\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcsc\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcoo\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m--> 609\u001b[0m X, y \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_data\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 610\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 611\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 612\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 613\u001b[0m \u001b[43m \u001b[49m\u001b[43my_numeric\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 614\u001b[0m \u001b[43m \u001b[49m\u001b[43mmulti_output\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 615\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_writeable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 616\u001b[0m 
\u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 618\u001b[0m has_sw \u001b[38;5;241m=\u001b[39m sample_weight \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 619\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_sw:\n",
|
||
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:650\u001b[0m, in \u001b[0;36mBaseEstimator._validate_data\u001b[1;34m(self, X, y, reset, validate_separately, cast_to_ndarray, **check_params)\u001b[0m\n\u001b[0;32m 648\u001b[0m y \u001b[38;5;241m=\u001b[39m check_array(y, input_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mcheck_y_params)\n\u001b[0;32m 649\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 650\u001b[0m X, y \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_X_y\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcheck_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 651\u001b[0m out \u001b[38;5;241m=\u001b[39m X, y\n\u001b[0;32m 653\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m check_params\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mensure_2d\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m):\n",
|
||
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1301\u001b[0m, in \u001b[0;36mcheck_X_y\u001b[1;34m(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, estimator)\u001b[0m\n\u001b[0;32m 1296\u001b[0m estimator_name \u001b[38;5;241m=\u001b[39m _check_estimator_name(estimator)\n\u001b[0;32m 1297\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 1298\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mestimator_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m requires y to be passed, but the target y is None\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1299\u001b[0m )\n\u001b[1;32m-> 1301\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_array\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1302\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1303\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1304\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_large_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_large_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1305\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1306\u001b[0m \u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1307\u001b[0m \u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1308\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mforce_writeable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_writeable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1309\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_all_finite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_all_finite\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1310\u001b[0m \u001b[43m \u001b[49m\u001b[43mensure_2d\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_2d\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1311\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_nd\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_nd\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1312\u001b[0m \u001b[43m \u001b[49m\u001b[43mensure_min_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_min_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1313\u001b[0m \u001b[43m \u001b[49m\u001b[43mensure_min_features\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_min_features\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1314\u001b[0m \u001b[43m \u001b[49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1315\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mX\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1316\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1318\u001b[0m y \u001b[38;5;241m=\u001b[39m _check_y(y, multi_output\u001b[38;5;241m=\u001b[39mmulti_output, y_numeric\u001b[38;5;241m=\u001b[39my_numeric, estimator\u001b[38;5;241m=\u001b[39mestimator)\n\u001b[0;32m 1320\u001b[0m check_consistent_length(X, y)\n",
|
||
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1012\u001b[0m, in \u001b[0;36mcheck_array\u001b[1;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)\u001b[0m\n\u001b[0;32m 1010\u001b[0m array \u001b[38;5;241m=\u001b[39m xp\u001b[38;5;241m.\u001b[39mastype(array, dtype, copy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 1011\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1012\u001b[0m array \u001b[38;5;241m=\u001b[39m \u001b[43m_asarray_with_order\u001b[49m\u001b[43m(\u001b[49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mxp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mxp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1013\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ComplexWarning \u001b[38;5;28;01mas\u001b[39;00m complex_warning:\n\u001b[0;32m 1014\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 1015\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mComplex data not supported\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(array)\n\u001b[0;32m 1016\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mcomplex_warning\u001b[39;00m\n",
|
||
"File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\utils\\_array_api.py:745\u001b[0m, in \u001b[0;36m_asarray_with_order\u001b[1;34m(array, dtype, order, copy, xp, device)\u001b[0m\n\u001b[0;32m 743\u001b[0m array \u001b[38;5;241m=\u001b[39m numpy\u001b[38;5;241m.\u001b[39marray(array, order\u001b[38;5;241m=\u001b[39morder, dtype\u001b[38;5;241m=\u001b[39mdtype)\n\u001b[0;32m 744\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 745\u001b[0m array \u001b[38;5;241m=\u001b[39m \u001b[43mnumpy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 747\u001b[0m \u001b[38;5;66;03m# At this point array is a NumPy ndarray. We convert it to an array\u001b[39;00m\n\u001b[0;32m 748\u001b[0m \u001b[38;5;66;03m# container that is consistent with the input's namespace.\u001b[39;00m\n\u001b[0;32m 749\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m xp\u001b[38;5;241m.\u001b[39masarray(array)\n",
|
||
"\u001b[1;31mValueError\u001b[0m: could not convert string to float: '(2018 PR10)'"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import math\n",
|
||
"from pandas import DataFrame\n",
|
||
"from sklearn import metrics\n",
|
||
"\n",
|
||
"for model_name in models.keys():\n",
|
||
" print(f\"Model: {model_name}\")\n",
|
||
"\n",
|
||
" fitted_model = models[model_name][\"model\"].fit(\n",
|
||
" X_train.values, y_train.values.ravel()\n",
|
||
" )\n",
|
||
" y_train_pred = fitted_model.predict(X_train.values)\n",
|
||
" y_test_pred = fitted_model.predict(X_test.values)\n",
|
||
" models[model_name][\"fitted\"] = fitted_model\n",
|
||
" models[model_name][\"train_preds\"] = y_train_pred\n",
|
||
" models[model_name][\"preds\"] = y_test_pred\n",
|
||
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
|
||
" metrics.mean_squared_error(y_train, y_train_pred)\n",
|
||
" )\n",
|
||
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
|
||
" metrics.mean_squared_error(y_test, y_test_pred)\n",
|
||
" )\n",
|
||
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
|
||
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
|
||
" )\n",
|
||
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "aimenv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|