AIM-PIbd-31-Alekseev-I-S/Lab_4/Lab4.ipynb
Иван Алексеев 9320d0ab41 капец....x2
2024-11-09 12:24:47 +04:00

4115 lines
430 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Начало 4-й лабораторной\n",
"#### Ближайшие объекты к Земле"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['id', 'name', 'est_diameter_min', 'est_diameter_max',\n",
" 'relative_velocity', 'miss_distance', 'orbiting_body', 'sentry_object',\n",
" 'absolute_magnitude', 'hazardous'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2162635</td>\n",
" <td>162635 (2000 SS164)</td>\n",
" <td>1.198271</td>\n",
" <td>2.679415</td>\n",
" <td>13569.249224</td>\n",
" <td>5.483974e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>16.73</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2277475</td>\n",
" <td>277475 (2005 WK4)</td>\n",
" <td>0.265800</td>\n",
" <td>0.594347</td>\n",
" <td>73588.726663</td>\n",
" <td>6.143813e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.00</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2512244</td>\n",
" <td>512244 (2015 YE18)</td>\n",
" <td>0.722030</td>\n",
" <td>1.614507</td>\n",
" <td>114258.692129</td>\n",
" <td>4.979872e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>17.83</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3596030</td>\n",
" <td>(2012 BV13)</td>\n",
" <td>0.096506</td>\n",
" <td>0.215794</td>\n",
" <td>24764.303138</td>\n",
" <td>2.543497e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3667127</td>\n",
" <td>(2014 GE35)</td>\n",
" <td>0.255009</td>\n",
" <td>0.570217</td>\n",
" <td>42737.733765</td>\n",
" <td>4.627557e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.09</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90831</th>\n",
" <td>3763337</td>\n",
" <td>(2016 VX1)</td>\n",
" <td>0.026580</td>\n",
" <td>0.059435</td>\n",
" <td>52078.886692</td>\n",
" <td>1.230039e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.00</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90832</th>\n",
" <td>3837603</td>\n",
" <td>(2019 AD3)</td>\n",
" <td>0.016771</td>\n",
" <td>0.037501</td>\n",
" <td>46114.605073</td>\n",
" <td>5.432121e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>26.00</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90833</th>\n",
" <td>54017201</td>\n",
" <td>(2020 JP3)</td>\n",
" <td>0.031956</td>\n",
" <td>0.071456</td>\n",
" <td>7566.807732</td>\n",
" <td>2.840077e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.60</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90834</th>\n",
" <td>54115824</td>\n",
" <td>(2021 CN5)</td>\n",
" <td>0.007321</td>\n",
" <td>0.016370</td>\n",
" <td>69199.154484</td>\n",
" <td>6.869206e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.80</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90835</th>\n",
" <td>54205447</td>\n",
" <td>(2021 TW7)</td>\n",
" <td>0.039862</td>\n",
" <td>0.089133</td>\n",
" <td>27024.455553</td>\n",
" <td>5.977213e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.12</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>90836 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"0 2162635 162635 (2000 SS164) 1.198271 2.679415 \n",
"1 2277475 277475 (2005 WK4) 0.265800 0.594347 \n",
"2 2512244 512244 (2015 YE18) 0.722030 1.614507 \n",
"3 3596030 (2012 BV13) 0.096506 0.215794 \n",
"4 3667127 (2014 GE35) 0.255009 0.570217 \n",
"... ... ... ... ... \n",
"90831 3763337 (2016 VX1) 0.026580 0.059435 \n",
"90832 3837603 (2019 AD3) 0.016771 0.037501 \n",
"90833 54017201 (2020 JP3) 0.031956 0.071456 \n",
"90834 54115824 (2021 CN5) 0.007321 0.016370 \n",
"90835 54205447 (2021 TW7) 0.039862 0.089133 \n",
"\n",
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
"0 13569.249224 5.483974e+07 Earth False \n",
"1 73588.726663 6.143813e+07 Earth False \n",
"2 114258.692129 4.979872e+07 Earth False \n",
"3 24764.303138 2.543497e+07 Earth False \n",
"4 42737.733765 4.627557e+07 Earth False \n",
"... ... ... ... ... \n",
"90831 52078.886692 1.230039e+07 Earth False \n",
"90832 46114.605073 5.432121e+07 Earth False \n",
"90833 7566.807732 2.840077e+07 Earth False \n",
"90834 69199.154484 6.869206e+07 Earth False \n",
"90835 27024.455553 5.977213e+07 Earth False \n",
"\n",
" absolute_magnitude hazardous \n",
"0 16.73 False \n",
"1 20.00 True \n",
"2 17.83 False \n",
"3 22.20 False \n",
"4 20.09 True \n",
"... ... ... \n",
"90831 25.00 False \n",
"90832 26.00 False \n",
"90833 24.60 False \n",
"90834 27.80 False \n",
"90835 24.12 False \n",
"\n",
"[90836 rows x 10 columns]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\".//static//csv//neo.csv\")\n",
"print(df.columns)\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Бизнес-цели:\n",
"\n",
"1. Идентификация потенциально опасных объектов\n",
"\n",
"Описание: классифицировать астероиды как потенциально опасные или безопасные (используя целевой признак \"hazardous\"). Эта задача актуальна для оценки рисков и подготовки соответствующих действий по защите Земли.\n",
"\n",
"2. Прогнозирование минимального расстояния до Земли\n",
"\n",
"Описание: предсказать минимальное расстояние до Земли для новых объектов на основе характеристик астероида (скорости, размера и других параметров). Это позволит планировать исследования и наблюдения в зависимости от опасности. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Определение достижимого уровня качества модели для первой задачи "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n",
"\n",
"Целевой признак -- hazardous"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2639</th>\n",
" <td>3634614</td>\n",
" <td>(2013 GT66)</td>\n",
" <td>0.024241</td>\n",
" <td>0.054205</td>\n",
" <td>43303.999094</td>\n",
" <td>4.814117e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29138</th>\n",
" <td>54143560</td>\n",
" <td>(2021 JU1)</td>\n",
" <td>0.030238</td>\n",
" <td>0.067615</td>\n",
" <td>21770.790211</td>\n",
" <td>5.646643e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.72</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36927</th>\n",
" <td>3836085</td>\n",
" <td>(2018 VQ3)</td>\n",
" <td>0.201630</td>\n",
" <td>0.450858</td>\n",
" <td>109358.123029</td>\n",
" <td>6.435051e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.60</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61855</th>\n",
" <td>3769804</td>\n",
" <td>(2017 DJ34)</td>\n",
" <td>0.160160</td>\n",
" <td>0.358129</td>\n",
" <td>78494.609756</td>\n",
" <td>5.595780e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>21.10</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15916</th>\n",
" <td>3824978</td>\n",
" <td>(2018 KS)</td>\n",
" <td>0.006991</td>\n",
" <td>0.015633</td>\n",
" <td>19077.749486</td>\n",
" <td>3.834648e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.90</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29491</th>\n",
" <td>3827304</td>\n",
" <td>(2018 RR1)</td>\n",
" <td>0.002658</td>\n",
" <td>0.005943</td>\n",
" <td>19826.895880</td>\n",
" <td>3.852881e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>30.00</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18373</th>\n",
" <td>3735468</td>\n",
" <td>(2015 WY1)</td>\n",
" <td>0.103408</td>\n",
" <td>0.231228</td>\n",
" <td>82856.544926</td>\n",
" <td>7.314334e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.05</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25031</th>\n",
" <td>3802041</td>\n",
" <td>(2018 FE3)</td>\n",
" <td>0.009651</td>\n",
" <td>0.021579</td>\n",
" <td>34243.774201</td>\n",
" <td>4.257719e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35456</th>\n",
" <td>3430406</td>\n",
" <td>(2008 TR10)</td>\n",
" <td>0.221083</td>\n",
" <td>0.494356</td>\n",
" <td>19557.289783</td>\n",
" <td>2.152970e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>20.40</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14305</th>\n",
" <td>3285300</td>\n",
" <td>(2005 OG3)</td>\n",
" <td>0.298233</td>\n",
" <td>0.666868</td>\n",
" <td>20309.404706</td>\n",
" <td>1.770015e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>19.75</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>72668 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"2639 3634614 (2013 GT66) 0.024241 0.054205 \n",
"29138 54143560 (2021 JU1) 0.030238 0.067615 \n",
"36927 3836085 (2018 VQ3) 0.201630 0.450858 \n",
"61855 3769804 (2017 DJ34) 0.160160 0.358129 \n",
"15916 3824978 (2018 KS) 0.006991 0.015633 \n",
"... ... ... ... ... \n",
"29491 3827304 (2018 RR1) 0.002658 0.005943 \n",
"18373 3735468 (2015 WY1) 0.103408 0.231228 \n",
"25031 3802041 (2018 FE3) 0.009651 0.021579 \n",
"35456 3430406 (2008 TR10) 0.221083 0.494356 \n",
"14305 3285300 (2005 OG3) 0.298233 0.666868 \n",
"\n",
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
"2639 43303.999094 4.814117e+07 Earth False \n",
"29138 21770.790211 5.646643e+07 Earth False \n",
"36927 109358.123029 6.435051e+07 Earth False \n",
"61855 78494.609756 5.595780e+07 Earth False \n",
"15916 19077.749486 3.834648e+07 Earth False \n",
"... ... ... ... ... \n",
"29491 19826.895880 3.852881e+07 Earth False \n",
"18373 82856.544926 7.314334e+07 Earth False \n",
"25031 34243.774201 4.257719e+07 Earth False \n",
"35456 19557.289783 2.152970e+07 Earth False \n",
"14305 20309.404706 1.770015e+07 Earth False \n",
"\n",
" absolute_magnitude hazardous \n",
"2639 25.20 False \n",
"29138 24.72 False \n",
"36927 20.60 False \n",
"61855 21.10 False \n",
"15916 27.90 False \n",
"... ... ... \n",
"29491 30.00 False \n",
"18373 22.05 False \n",
"25031 27.20 False \n",
"35456 20.40 False \n",
"14305 19.75 False \n",
"\n",
"[72668 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2639</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29138</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36927</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61855</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15916</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29491</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18373</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25031</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35456</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14305</th>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>72668 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" hazardous\n",
"2639 False\n",
"29138 False\n",
"36927 False\n",
"61855 False\n",
"15916 False\n",
"... ...\n",
"29491 False\n",
"18373 False\n",
"25031 False\n",
"35456 False\n",
"14305 False\n",
"\n",
"[72668 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>9040</th>\n",
" <td>2474532</td>\n",
" <td>474532 (2003 VG1)</td>\n",
" <td>0.472667</td>\n",
" <td>1.056915</td>\n",
" <td>21779.237137</td>\n",
" <td>3.443050e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>18.75</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>67305</th>\n",
" <td>3774018</td>\n",
" <td>(2017 HF1)</td>\n",
" <td>0.084053</td>\n",
" <td>0.187949</td>\n",
" <td>53291.016226</td>\n",
" <td>6.862591e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.50</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>77741</th>\n",
" <td>54269585</td>\n",
" <td>(2022 GQ2)</td>\n",
" <td>0.018220</td>\n",
" <td>0.040742</td>\n",
" <td>43089.046433</td>\n",
" <td>2.592726e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.82</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>81520</th>\n",
" <td>54097970</td>\n",
" <td>(2020 XS)</td>\n",
" <td>0.152952</td>\n",
" <td>0.342011</td>\n",
" <td>93246.455599</td>\n",
" <td>4.709054e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>21.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>508</th>\n",
" <td>3730802</td>\n",
" <td>(2015 TT238)</td>\n",
" <td>0.031956</td>\n",
" <td>0.071456</td>\n",
" <td>37708.258544</td>\n",
" <td>4.232149e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.60</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28261</th>\n",
" <td>3532365</td>\n",
" <td>(2010 MH1)</td>\n",
" <td>0.139494</td>\n",
" <td>0.311918</td>\n",
" <td>37604.980238</td>\n",
" <td>7.369507e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>21.40</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1159</th>\n",
" <td>54073345</td>\n",
" <td>(2020 UE)</td>\n",
" <td>0.020728</td>\n",
" <td>0.046349</td>\n",
" <td>36720.077728</td>\n",
" <td>3.366114e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>25.54</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48095</th>\n",
" <td>3836195</td>\n",
" <td>(2018 VT7)</td>\n",
" <td>0.006991</td>\n",
" <td>0.015633</td>\n",
" <td>7616.496535</td>\n",
" <td>6.376350e+06</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>27.90</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90234</th>\n",
" <td>3752902</td>\n",
" <td>(2016 JG12)</td>\n",
" <td>0.084053</td>\n",
" <td>0.187949</td>\n",
" <td>21894.554692</td>\n",
" <td>5.736984e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.50</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12013</th>\n",
" <td>3445077</td>\n",
" <td>(2009 BM58)</td>\n",
" <td>0.038420</td>\n",
" <td>0.085909</td>\n",
" <td>49828.611609</td>\n",
" <td>4.305599e+07</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>24.20</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>18168 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"9040 2474532 474532 (2003 VG1) 0.472667 1.056915 \n",
"67305 3774018 (2017 HF1) 0.084053 0.187949 \n",
"77741 54269585 (2022 GQ2) 0.018220 0.040742 \n",
"81520 54097970 (2020 XS) 0.152952 0.342011 \n",
"508 3730802 (2015 TT238) 0.031956 0.071456 \n",
"... ... ... ... ... \n",
"28261 3532365 (2010 MH1) 0.139494 0.311918 \n",
"1159 54073345 (2020 UE) 0.020728 0.046349 \n",
"48095 3836195 (2018 VT7) 0.006991 0.015633 \n",
"90234 3752902 (2016 JG12) 0.084053 0.187949 \n",
"12013 3445077 (2009 BM58) 0.038420 0.085909 \n",
"\n",
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
"9040 21779.237137 3.443050e+07 Earth False \n",
"67305 53291.016226 6.862591e+07 Earth False \n",
"77741 43089.046433 2.592726e+07 Earth False \n",
"81520 93246.455599 4.709054e+07 Earth False \n",
"508 37708.258544 4.232149e+07 Earth False \n",
"... ... ... ... ... \n",
"28261 37604.980238 7.369507e+07 Earth False \n",
"1159 36720.077728 3.366114e+07 Earth False \n",
"48095 7616.496535 6.376350e+06 Earth False \n",
"90234 21894.554692 5.736984e+07 Earth False \n",
"12013 49828.611609 4.305599e+07 Earth False \n",
"\n",
" absolute_magnitude hazardous \n",
"9040 18.75 False \n",
"67305 22.50 False \n",
"77741 25.82 False \n",
"81520 21.20 False \n",
"508 24.60 False \n",
"... ... ... \n",
"28261 21.40 False \n",
"1159 25.54 False \n",
"48095 27.90 False \n",
"90234 22.50 False \n",
"12013 24.20 False \n",
"\n",
"[18168 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>9040</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>67305</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>77741</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>81520</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>508</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28261</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1159</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48095</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90234</th>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12013</th>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>18168 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" hazardous\n",
"9040 False\n",
"67305 False\n",
"77741 False\n",
"81520 False\n",
"508 False\n",
"... ...\n",
"28261 False\n",
"1159 False\n",
"48095 False\n",
"90234 False\n",
"12013 False\n",
"\n",
"[18168 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Устанавливаем случайное состояние\n",
"random_state = 42\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"hazardous\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование конвейера для классификации данных\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"features_engineering -- трансформер для конструирования признаков\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.ensemble import RandomForestRegressor # Пример регрессионной модели\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"\n",
"class StarbucksFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
" \n",
" def fit(self, X, y=None):\n",
" return self\n",
"\n",
" def transform(self, X, y=None):\n",
" # Преобразование категориальных столбцов в числовые 1/0\n",
" X[\"hazardous\"] = X[\"hazardous\"].astype(int)\n",
" X[\"sentry_object\"] = X[\"sentry_object\"].astype(int)\n",
" X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n",
" return X\n",
"\n",
" def get_feature_names_out(self, features_in):\n",
" return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
"\n",
"# Указываем столбцы, которые нужно удалить и обрабатывать\n",
"columns_to_drop = [\"name\", \"orbiting_body\"]\n",
"num_columns = [\"est_diameter_min\", \"est_diameter_max\",\n",
" \"relative_velocity\", \"miss_distance\", \"sentry_object\",\n",
" \"absolute_magnitude\", \"hazardous\"]\n",
"cat_columns = [\"sentry_object\", \"hazardous\"]\n",
" \n",
"\n",
"# Определяем предобработку для численных данных\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Определяем предобработку для категориальных данных\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Подготовка признаков с использованием ColumnTransformer\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Удаление нежелательных столбцов\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Постобработка признаков\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_cat\", preprocessing_cat, [\"Cabin_type\"]), \n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Создание окончательного конвейера\n",
"pipeline = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" (\"model\", RandomForestRegressor()) # Выбор модели для обучения\n",
" ]\n",
")\n",
"\n",
"# Использование конвейера\n",
"def train_pipeline(X, y):\n",
" pipeline.fit(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Демонстрация работы конвейера"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" <th>id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2639</th>\n",
" <td>-0.331616</td>\n",
" <td>-0.331616</td>\n",
" <td>-0.188160</td>\n",
" <td>0.494297</td>\n",
" <td>0.0</td>\n",
" <td>0.577785</td>\n",
" <td>-0.328347</td>\n",
" <td>3634614</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29138</th>\n",
" <td>-0.312486</td>\n",
" <td>-0.312486</td>\n",
" <td>-1.040729</td>\n",
" <td>0.866716</td>\n",
" <td>0.0</td>\n",
" <td>0.412170</td>\n",
" <td>-0.328347</td>\n",
" <td>54143560</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36927</th>\n",
" <td>0.234246</td>\n",
" <td>0.234246</td>\n",
" <td>2.427134</td>\n",
" <td>1.219399</td>\n",
" <td>0.0</td>\n",
" <td>-1.009355</td>\n",
" <td>-0.328347</td>\n",
" <td>3836085</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61855</th>\n",
" <td>0.101960</td>\n",
" <td>0.101960</td>\n",
" <td>1.205148</td>\n",
" <td>0.843963</td>\n",
" <td>0.0</td>\n",
" <td>-0.836840</td>\n",
" <td>-0.328347</td>\n",
" <td>3769804</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15916</th>\n",
" <td>-0.386643</td>\n",
" <td>-0.386643</td>\n",
" <td>-1.147355</td>\n",
" <td>0.056145</td>\n",
" <td>0.0</td>\n",
" <td>1.509367</td>\n",
" <td>-0.328347</td>\n",
" <td>3824978</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29491</th>\n",
" <td>-0.400466</td>\n",
" <td>-0.400466</td>\n",
" <td>-1.117694</td>\n",
" <td>0.064301</td>\n",
" <td>0.0</td>\n",
" <td>2.233931</td>\n",
" <td>-0.328347</td>\n",
" <td>3827304</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18373</th>\n",
" <td>-0.079077</td>\n",
" <td>-0.079077</td>\n",
" <td>1.377851</td>\n",
" <td>1.612734</td>\n",
" <td>0.0</td>\n",
" <td>-0.509061</td>\n",
" <td>-0.328347</td>\n",
" <td>3735468</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25031</th>\n",
" <td>-0.378159</td>\n",
" <td>-0.378159</td>\n",
" <td>-0.546884</td>\n",
" <td>0.245400</td>\n",
" <td>0.0</td>\n",
" <td>1.267846</td>\n",
" <td>-0.328347</td>\n",
" <td>3802041</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35456</th>\n",
" <td>0.296300</td>\n",
" <td>0.296300</td>\n",
" <td>-1.128369</td>\n",
" <td>-0.696130</td>\n",
" <td>0.0</td>\n",
" <td>-1.078361</td>\n",
" <td>-0.328347</td>\n",
" <td>3430406</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14305</th>\n",
" <td>0.542404</td>\n",
" <td>0.542404</td>\n",
" <td>-1.098590</td>\n",
" <td>-0.867440</td>\n",
" <td>0.0</td>\n",
" <td>-1.302631</td>\n",
" <td>-0.328347</td>\n",
" <td>3285300</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>72668 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" est_diameter_min est_diameter_max relative_velocity miss_distance \\\n",
"2639 -0.331616 -0.331616 -0.188160 0.494297 \n",
"29138 -0.312486 -0.312486 -1.040729 0.866716 \n",
"36927 0.234246 0.234246 2.427134 1.219399 \n",
"61855 0.101960 0.101960 1.205148 0.843963 \n",
"15916 -0.386643 -0.386643 -1.147355 0.056145 \n",
"... ... ... ... ... \n",
"29491 -0.400466 -0.400466 -1.117694 0.064301 \n",
"18373 -0.079077 -0.079077 1.377851 1.612734 \n",
"25031 -0.378159 -0.378159 -0.546884 0.245400 \n",
"35456 0.296300 0.296300 -1.128369 -0.696130 \n",
"14305 0.542404 0.542404 -1.098590 -0.867440 \n",
"\n",
" sentry_object absolute_magnitude hazardous id \n",
"2639 0.0 0.577785 -0.328347 3634614 \n",
"29138 0.0 0.412170 -0.328347 54143560 \n",
"36927 0.0 -1.009355 -0.328347 3836085 \n",
"61855 0.0 -0.836840 -0.328347 3769804 \n",
"15916 0.0 1.509367 -0.328347 3824978 \n",
"... ... ... ... ... \n",
"29491 0.0 2.233931 -0.328347 3827304 \n",
"18373 0.0 -0.509061 -0.328347 3735468 \n",
"25031 0.0 1.267846 -0.328347 3802041 \n",
"35456 0.0 -1.078361 -0.328347 3430406 \n",
"14305 0.0 -1.302631 -0.328347 3285300 \n",
"\n",
"[72668 rows x 8 columns]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для классификации\n",
" logistic -- логистическая регрессия\n",
"\n",
"ridge -- гребневая регрессия\n",
"\n",
"decision_tree -- дерево решений\n",
"\n",
"knn -- k-ближайших соседей\n",
"\n",
"naive_bayes -- наивный Байесовский классификатор\n",
"\n",
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"\n",
"random_forest -- метод случайного леса (набор деревьев решений)\n",
"\n",
"mlp -- многослойный персептрон (нейронная сеть)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Сводная таблица оценок качества для использованных моделей классификации"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"hazardous\", \"safe\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"16400 - количество истинных положительных диагнозов (True Positives), где модель правильно определила объекты как \"hazardous\".\n",
"\n",
"1768 в некоторых моделях - количество ложных отрицательных диагнозов (False Negatives), где модель неправильно определила объекты, которые на самом деле принадлежат к классу \"hazardous\", но были отнесены к классу \"safe\". \n",
"\n",
"Исходя из значений True Positives и False Negatives, можно сказать, что модель имеет высокую точность при предсказании класса \"hazardous\". В принципе, уровень ложных отрицательных результатов в некоторых моделях (1768) говорит нам о том, что существует некотрое небольшое количество примеров, которые модель пропускает.\n",
"\n",
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_a22cf_row0_col0, #T_a22cf_row0_col1, #T_a22cf_row0_col2, #T_a22cf_row0_col3, #T_a22cf_row1_col0, #T_a22cf_row1_col1, #T_a22cf_row1_col2, #T_a22cf_row1_col3, #T_a22cf_row2_col0, #T_a22cf_row2_col1, #T_a22cf_row2_col2, #T_a22cf_row2_col3, #T_a22cf_row3_col0, #T_a22cf_row3_col1, #T_a22cf_row3_col2, #T_a22cf_row3_col3, #T_a22cf_row7_col2, #T_a22cf_row7_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_a22cf_row0_col4, #T_a22cf_row0_col5, #T_a22cf_row0_col6, #T_a22cf_row0_col7, #T_a22cf_row1_col4, #T_a22cf_row1_col5, #T_a22cf_row1_col6, #T_a22cf_row1_col7, #T_a22cf_row2_col4, #T_a22cf_row2_col5, #T_a22cf_row2_col6, #T_a22cf_row2_col7, #T_a22cf_row3_col4, #T_a22cf_row3_col5, #T_a22cf_row3_col6, #T_a22cf_row3_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row4_col0 {\n",
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
"#T_a22cf_row4_col1 {\n",
" background-color: #77d153;\n",
" color: #000000;\n",
"}\n",
"#T_a22cf_row4_col2 {\n",
" background-color: #63cb5f;\n",
" color: #000000;\n",
"}\n",
"#T_a22cf_row4_col3 {\n",
" background-color: #4ac16d;\n",
" color: #000000;\n",
"}\n",
"#T_a22cf_row4_col4 {\n",
" background-color: #c03a83;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row4_col5 {\n",
" background-color: #b32c8e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row4_col6 {\n",
" background-color: #c7427c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row4_col7 {\n",
" background-color: #bd3786;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row5_col0, #T_a22cf_row5_col1, #T_a22cf_row5_col2, #T_a22cf_row5_col3, #T_a22cf_row6_col0, #T_a22cf_row6_col1, #T_a22cf_row6_col2, #T_a22cf_row6_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row5_col4, #T_a22cf_row6_col4 {\n",
" background-color: #8004a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row5_col5, #T_a22cf_row6_col5 {\n",
" background-color: #7d03a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row5_col6, #T_a22cf_row5_col7, #T_a22cf_row6_col6, #T_a22cf_row6_col7, #T_a22cf_row7_col4, #T_a22cf_row7_col5 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row7_col0 {\n",
" background-color: #25ac82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row7_col1 {\n",
" background-color: #26ad81;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row7_col6 {\n",
" background-color: #ac2694;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a22cf_row7_col7 {\n",
" background-color: #ad2793;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_a22cf\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_a22cf_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_a22cf_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_a22cf_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_a22cf_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_a22cf_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_a22cf_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_a22cf_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_a22cf_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_a22cf_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_a22cf_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_a22cf_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_a22cf_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_a22cf_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_a22cf_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_a22cf_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_a22cf_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_a22cf_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a22cf_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_a22cf_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_a22cf_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_a22cf_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_a22cf_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_a22cf_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_a22cf_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_a22cf_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_a22cf_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a22cf_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_a22cf_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_a22cf_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_a22cf_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_a22cf_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_a22cf_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_a22cf_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_a22cf_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_a22cf_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a22cf_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
" <td id=\"T_a22cf_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_a22cf_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_a22cf_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_a22cf_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_a22cf_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" <td id=\"T_a22cf_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
" <td id=\"T_a22cf_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
" <td id=\"T_a22cf_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a22cf_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_a22cf_row4_col0\" class=\"data row4 col0\" >0.884596</td>\n",
" <td id=\"T_a22cf_row4_col1\" class=\"data row4 col1\" >0.826374</td>\n",
" <td id=\"T_a22cf_row4_col2\" class=\"data row4 col2\" >0.744627</td>\n",
" <td id=\"T_a22cf_row4_col3\" class=\"data row4 col3\" >0.638009</td>\n",
" <td id=\"T_a22cf_row4_col4\" class=\"data row4 col4\" >0.965693</td>\n",
" <td id=\"T_a22cf_row4_col5\" class=\"data row4 col5\" >0.951728</td>\n",
" <td id=\"T_a22cf_row4_col6\" class=\"data row4 col6\" >0.808599</td>\n",
" <td id=\"T_a22cf_row4_col7\" class=\"data row4 col7\" >0.720077</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a22cf_level0_row5\" class=\"row_heading level0 row5\" >naive_bayes</th>\n",
" <td id=\"T_a22cf_row5_col0\" class=\"data row5 col0\" >0.000000</td>\n",
" <td id=\"T_a22cf_row5_col1\" class=\"data row5 col1\" >0.000000</td>\n",
" <td id=\"T_a22cf_row5_col2\" class=\"data row5 col2\" >0.000000</td>\n",
" <td id=\"T_a22cf_row5_col3\" class=\"data row5 col3\" >0.000000</td>\n",
" <td id=\"T_a22cf_row5_col4\" class=\"data row5 col4\" >0.902681</td>\n",
" <td id=\"T_a22cf_row5_col5\" class=\"data row5 col5\" >0.902686</td>\n",
" <td id=\"T_a22cf_row5_col6\" class=\"data row5 col6\" >0.000000</td>\n",
" <td id=\"T_a22cf_row5_col7\" class=\"data row5 col7\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a22cf_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
" <td id=\"T_a22cf_row6_col0\" class=\"data row6 col0\" >0.000000</td>\n",
" <td id=\"T_a22cf_row6_col1\" class=\"data row6 col1\" >0.000000</td>\n",
" <td id=\"T_a22cf_row6_col2\" class=\"data row6 col2\" >0.000000</td>\n",
" <td id=\"T_a22cf_row6_col3\" class=\"data row6 col3\" >0.000000</td>\n",
" <td id=\"T_a22cf_row6_col4\" class=\"data row6 col4\" >0.902681</td>\n",
" <td id=\"T_a22cf_row6_col5\" class=\"data row6 col5\" >0.902686</td>\n",
" <td id=\"T_a22cf_row6_col6\" class=\"data row6 col6\" >0.000000</td>\n",
" <td id=\"T_a22cf_row6_col7\" class=\"data row6 col7\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a22cf_level0_row7\" class=\"row_heading level0 row7\" >ridge</th>\n",
" <td id=\"T_a22cf_row7_col0\" class=\"data row7 col0\" >0.415780</td>\n",
" <td id=\"T_a22cf_row7_col1\" class=\"data row7 col1\" >0.421253</td>\n",
" <td id=\"T_a22cf_row7_col2\" class=\"data row7 col2\" >1.000000</td>\n",
" <td id=\"T_a22cf_row7_col3\" class=\"data row7 col3\" >1.000000</td>\n",
" <td id=\"T_a22cf_row7_col4\" class=\"data row7 col4\" >0.863255</td>\n",
" <td id=\"T_a22cf_row7_col5\" class=\"data row7 col5\" >0.866303</td>\n",
" <td id=\"T_a22cf_row7_col6\" class=\"data row7 col6\" >0.587351</td>\n",
" <td id=\"T_a22cf_row7_col7\" class=\"data row7 col7\" >0.592791</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1b3e1e74950>"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Все модели в данной выборке, а именно логистическая регрессия, ридж-регрессия, дерево решений, KNN, наивный байесовский классификатор, градиентный бустинг, случайный лес и многослойный перцептрон (MLP) демонстрируют неплохие значения по всем метрикам на обучающих и тестовых наборах данных.\n",
"\n",
"Модели Naive Bayes и MLP не так эффективны по сравнению с другими, но в некоторых метриках показывают высокие результаты. \n",
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_46430_row0_col0, #T_46430_row0_col1, #T_46430_row1_col0, #T_46430_row1_col1, #T_46430_row2_col0, #T_46430_row2_col1, #T_46430_row3_col0, #T_46430_row3_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_46430_row0_col2, #T_46430_row0_col3, #T_46430_row0_col4, #T_46430_row1_col2, #T_46430_row1_col3, #T_46430_row1_col4, #T_46430_row2_col2, #T_46430_row2_col3, #T_46430_row2_col4, #T_46430_row3_col2, #T_46430_row3_col3, #T_46430_row3_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_46430_row4_col0, #T_46430_row6_col1, #T_46430_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_46430_row4_col1 {\n",
" background-color: #40bd72;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_46430_row4_col2 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_46430_row4_col3, #T_46430_row6_col2 {\n",
" background-color: #a51f99;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_46430_row4_col4 {\n",
" background-color: #ae2892;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_46430_row5_col0 {\n",
" background-color: #4ac16d;\n",
" color: #000000;\n",
"}\n",
"#T_46430_row5_col1 {\n",
" background-color: #5cc863;\n",
" color: #000000;\n",
"}\n",
"#T_46430_row5_col2 {\n",
" background-color: #d14e72;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_46430_row5_col3 {\n",
" background-color: #ba3388;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_46430_row5_col4 {\n",
" background-color: #bb3488;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_46430_row6_col0, #T_46430_row7_col0 {\n",
" background-color: #1e9d89;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_46430_row6_col3, #T_46430_row6_col4, #T_46430_row7_col2, #T_46430_row7_col3, #T_46430_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_46430\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_46430_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_46430_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_46430_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_46430_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_46430_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_46430_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_46430_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_46430_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_46430_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_46430_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_46430_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_46430_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_46430_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_46430_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_46430_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_46430_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_46430_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_46430_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_46430_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_46430_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_46430_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_46430_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_46430_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_46430_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
" <td id=\"T_46430_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_46430_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_46430_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_46430_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_46430_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_46430_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
" <td id=\"T_46430_row4_col0\" class=\"data row4 col0\" >0.866303</td>\n",
" <td id=\"T_46430_row4_col1\" class=\"data row4 col1\" >0.592791</td>\n",
" <td id=\"T_46430_row4_col2\" class=\"data row4 col2\" >0.995675</td>\n",
" <td id=\"T_46430_row4_col3\" class=\"data row4 col3\" >0.528180</td>\n",
" <td id=\"T_46430_row4_col4\" class=\"data row4 col4\" >0.599051</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_46430_level0_row5\" class=\"row_heading level0 row5\" >knn</th>\n",
" <td id=\"T_46430_row5_col0\" class=\"data row5 col0\" >0.951728</td>\n",
" <td id=\"T_46430_row5_col1\" class=\"data row5 col1\" >0.720077</td>\n",
" <td id=\"T_46430_row5_col2\" class=\"data row5 col2\" >0.953405</td>\n",
" <td id=\"T_46430_row5_col3\" class=\"data row5 col3\" >0.694141</td>\n",
" <td id=\"T_46430_row5_col4\" class=\"data row5 col4\" >0.701100</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_46430_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_46430_row6_col0\" class=\"data row6 col0\" >0.902686</td>\n",
" <td id=\"T_46430_row6_col1\" class=\"data row6 col1\" >0.000000</td>\n",
" <td id=\"T_46430_row6_col2\" class=\"data row6 col2\" >0.766341</td>\n",
" <td id=\"T_46430_row6_col3\" class=\"data row6 col3\" >0.000000</td>\n",
" <td id=\"T_46430_row6_col4\" class=\"data row6 col4\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_46430_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_46430_row7_col0\" class=\"data row7 col0\" >0.902686</td>\n",
" <td id=\"T_46430_row7_col1\" class=\"data row7 col1\" >0.000000</td>\n",
" <td id=\"T_46430_row7_col2\" class=\"data row7 col2\" >0.500000</td>\n",
" <td id=\"T_46430_row7_col3\" class=\"data row7 col3\" >0.000000</td>\n",
" <td id=\"T_46430_row7_col4\" class=\"data row7 col4\" >0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1b3dda0e660>"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Схожий вывод можно сделать и для следующих метрик: Accuracy, F1, ROC AUC, Cohen's Kappa и MCC. Все модели, кроме Naive Bayes и MLP, указывают на хорошо-развитую способность к выделению классов"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>Predicted</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [id, Predicted, name, est_diameter_min, est_diameter_max, relative_velocity, miss_distance, orbiting_body, sentry_object, absolute_magnitude, hazardous]\n",
"Index: []"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"hazardous\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Пример использования обученной модели (конвейера) для предсказания\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>name</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>orbiting_body</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>67305</th>\n",
" <td>3774018</td>\n",
" <td>(2017 HF1)</td>\n",
" <td>0.084053</td>\n",
" <td>0.187949</td>\n",
" <td>53291.016226</td>\n",
" <td>68625911.198806</td>\n",
" <td>Earth</td>\n",
" <td>False</td>\n",
" <td>22.5</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id name est_diameter_min est_diameter_max \\\n",
"67305 3774018 (2017 HF1) 0.084053 0.187949 \n",
"\n",
" relative_velocity miss_distance orbiting_body sentry_object \\\n",
"67305 53291.016226 68625911.198806 Earth False \n",
"\n",
" absolute_magnitude hazardous \n",
"67305 22.5 False "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>sentry_object</th>\n",
" <th>absolute_magnitude</th>\n",
" <th>hazardous</th>\n",
" <th>id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>67305</th>\n",
" <td>-0.140818</td>\n",
" <td>-0.140818</td>\n",
" <td>0.207258</td>\n",
" <td>1.410653</td>\n",
" <td>0.0</td>\n",
" <td>-0.353797</td>\n",
" <td>-0.328347</td>\n",
" <td>3774018.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" est_diameter_min est_diameter_max relative_velocity miss_distance \\\n",
"67305 -0.140818 -0.140818 0.207258 1.410653 \n",
"\n",
" sentry_object absolute_magnitude hazardous id \n",
"67305 0.0 -0.353797 -0.328347 3774018.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: False (proba: [9.99855425e-01 1.44575476e-04])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 0'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 67305\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Подбор гиперпараметров методом поиска по сетке "
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 5,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 50}"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"import pandas as pd\n",
"\n",
"\n",
"# Определяем числовые признаки\n",
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
"\n",
"# Установка random_state\n",
"random_state = 42\n",
"\n",
"# Определение трансформера\n",
"pipeline_end = ColumnTransformer([\n",
" ('numeric', StandardScaler(), numeric_features),\n",
" # Добавьте другие трансформеры, если требуется\n",
"])\n",
"\n",
"# Объявление модели\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=50,\n",
")\n",
"\n",
"# Создание пайплайна с корректными шагами\n",
"result = {}\n",
"\n",
"# Обучение модели\n",
"result[\"pipeline\"] = Pipeline([\n",
" (\"pipeline\", pipeline_end),\n",
" (\"model\", optimized_model)\n",
"]).fit(X_train, y_train.values.ravel())\n",
"\n",
"# Прогнозирование и расчет метрик\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"# Метрики для оценки модели\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_55496_row0_col0, #T_55496_row0_col1, #T_55496_row0_col2, #T_55496_row0_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_55496_row0_col4, #T_55496_row0_col5, #T_55496_row0_col6, #T_55496_row0_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_55496_row1_col0, #T_55496_row1_col1, #T_55496_row1_col2, #T_55496_row1_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_55496_row1_col4, #T_55496_row1_col5, #T_55496_row1_col6, #T_55496_row1_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_55496\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_55496_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_55496_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_55496_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_55496_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_55496_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_55496_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_55496_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_55496_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_55496_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_55496_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_55496_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_55496_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_55496_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_55496_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_55496_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_55496_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_55496_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_55496_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_55496_row1_col0\" class=\"data row1 col0\" >0.833191</td>\n",
" <td id=\"T_55496_row1_col1\" class=\"data row1 col1\" >0.862500</td>\n",
" <td id=\"T_55496_row1_col2\" class=\"data row1 col2\" >0.138433</td>\n",
" <td id=\"T_55496_row1_col3\" class=\"data row1 col3\" >0.156109</td>\n",
" <td id=\"T_55496_row1_col4\" class=\"data row1 col4\" >0.913456</td>\n",
" <td id=\"T_55496_row1_col5\" class=\"data row1 col5\" >0.915456</td>\n",
" <td id=\"T_55496_row1_col6\" class=\"data row1 col6\" >0.237420</td>\n",
" <td id=\"T_55496_row1_col7\" class=\"data row1 col7\" >0.264368</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1b3e1be0920>"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_36483_row0_col0, #T_36483_row0_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_36483_row0_col2, #T_36483_row0_col3, #T_36483_row0_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_36483_row1_col0, #T_36483_row1_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_36483_row1_col2, #T_36483_row1_col3, #T_36483_row1_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_36483\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_36483_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_36483_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_36483_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_36483_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_36483_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_36483_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_36483_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_36483_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_36483_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_36483_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_36483_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_36483_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_36483_row1_col0\" class=\"data row1 col0\" >0.915456</td>\n",
" <td id=\"T_36483_row1_col1\" class=\"data row1 col1\" >0.264368</td>\n",
" <td id=\"T_36483_row1_col2\" class=\"data row1 col2\" >0.927493</td>\n",
" <td id=\"T_36483_row1_col3\" class=\"data row1 col3\" >0.241751</td>\n",
" <td id=\"T_36483_row1_col4\" class=\"data row1 col4\" >0.345694</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1b3e1be2de0>"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"hazardous\", \"safe\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В желтых квадрате мы наблюдаем значение 16400, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"hazardsous\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
"\n",
"В фиолетвом квадрате значение 276 указывает на количество правильно классифицированных объектов, отнесенных к классу \"More\". Это является показателем не такой высокой точности модели в определении объектов данного класса."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 201,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(5000, 6)\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>absolute_magnitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3943344</td>\n",
" <td>0.024241</td>\n",
" <td>0.054205</td>\n",
" <td>22148.962596</td>\n",
" <td>5.028574e+07</td>\n",
" <td>25.20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3879239</td>\n",
" <td>0.012722</td>\n",
" <td>0.028447</td>\n",
" <td>26477.211836</td>\n",
" <td>1.683201e+06</td>\n",
" <td>26.60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3879244</td>\n",
" <td>0.013322</td>\n",
" <td>0.029788</td>\n",
" <td>33770.201397</td>\n",
" <td>3.943220e+06</td>\n",
" <td>26.50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2481965</td>\n",
" <td>0.193444</td>\n",
" <td>0.432554</td>\n",
" <td>43599.575296</td>\n",
" <td>7.346837e+07</td>\n",
" <td>20.69</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3789471</td>\n",
" <td>0.044112</td>\n",
" <td>0.098637</td>\n",
" <td>36398.080883</td>\n",
" <td>6.352916e+07</td>\n",
" <td>23.90</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4995</th>\n",
" <td>3468663</td>\n",
" <td>0.006677</td>\n",
" <td>0.014929</td>\n",
" <td>20300.398051</td>\n",
" <td>1.700006e+06</td>\n",
" <td>28.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4996</th>\n",
" <td>3620670</td>\n",
" <td>0.105817</td>\n",
" <td>0.236614</td>\n",
" <td>36514.062162</td>\n",
" <td>6.945396e+07</td>\n",
" <td>22.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4997</th>\n",
" <td>3562321</td>\n",
" <td>0.192555</td>\n",
" <td>0.430566</td>\n",
" <td>68895.907750</td>\n",
" <td>5.209557e+07</td>\n",
" <td>20.70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4998</th>\n",
" <td>3440771</td>\n",
" <td>0.253837</td>\n",
" <td>0.567597</td>\n",
" <td>61336.513568</td>\n",
" <td>5.037204e+07</td>\n",
" <td>20.10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4999</th>\n",
" <td>54065901</td>\n",
" <td>0.015295</td>\n",
" <td>0.034201</td>\n",
" <td>18389.028188</td>\n",
" <td>5.627145e+07</td>\n",
" <td>26.20</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5000 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" id est_diameter_min est_diameter_max relative_velocity \\\n",
"0 3943344 0.024241 0.054205 22148.962596 \n",
"1 3879239 0.012722 0.028447 26477.211836 \n",
"2 3879244 0.013322 0.029788 33770.201397 \n",
"3 2481965 0.193444 0.432554 43599.575296 \n",
"4 3789471 0.044112 0.098637 36398.080883 \n",
"... ... ... ... ... \n",
"4995 3468663 0.006677 0.014929 20300.398051 \n",
"4996 3620670 0.105817 0.236614 36514.062162 \n",
"4997 3562321 0.192555 0.430566 68895.907750 \n",
"4998 3440771 0.253837 0.567597 61336.513568 \n",
"4999 54065901 0.015295 0.034201 18389.028188 \n",
"\n",
" miss_distance absolute_magnitude \n",
"0 5.028574e+07 25.20 \n",
"1 1.683201e+06 26.60 \n",
"2 3.943220e+06 26.50 \n",
"3 7.346837e+07 20.69 \n",
"4 6.352916e+07 23.90 \n",
"... ... ... \n",
"4995 1.700006e+06 28.00 \n",
"4996 6.945396e+07 22.00 \n",
"4997 5.209557e+07 20.70 \n",
"4998 5.037204e+07 20.10 \n",
"4999 5.627145e+07 26.20 \n",
"\n",
"[5000 rows x 6 columns]"
]
},
"execution_count": 201,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import set_config\n",
"\n",
"random_state=42\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\".//static//csv//neo.csv\")\n",
"# Удаление столбцов \"sentry_object\" и \"hazardous\"\n",
"df = df.drop(columns=[\"sentry_object\", \"hazardous\", \"orbiting_body\", \"name\"])\n",
"\n",
"# Ограничение количества записей до 5,000\n",
"df = df.sample(n=5000, random_state=random_state).reset_index(drop=True)\n",
"\n",
"# Проверка итогового DataFrame\n",
"print(df.shape) # Убедитесь, что размер 5,000 строк\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 202,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" est_diameter_min est_diameter_max relative_velocity miss_distance \\\n",
"0 1.198271 2.679415 13569.249224 5.483974e+07 \n",
"1 0.265800 0.594347 73588.726663 6.143813e+07 \n",
"2 0.722030 1.614507 114258.692129 4.979872e+07 \n",
"3 0.096506 0.215794 24764.303138 2.543497e+07 \n",
"4 0.255009 0.570217 42737.733765 4.627557e+07 \n",
"\n",
" impact_damage_index \n",
"0 0.000480 \n",
"1 0.000515 \n",
"2 0.002680 \n",
"3 0.000152 \n",
"4 0.000381 \n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"# Загрузка данных (замените путь на актуальный, если требуется)\n",
"df = pd.read_csv(\".//static//csv//neo.csv\")\n",
"\n",
"# Убедитесь, что столбцы в данных содержат необходимые характеристики\n",
"required_columns = [\"est_diameter_min\", \"est_diameter_max\", \"relative_velocity\", \"miss_distance\"]\n",
"missing_columns = [col for col in required_columns if col not in df.columns]\n",
"if missing_columns:\n",
" raise ValueError(f\"Отсутствуют столбцы: {missing_columns}\")\n",
"\n",
"# Создание переменной \"impact_damage_index\"\n",
"# Формула, используемая ниже, условная и может быть скорректирована в зависимости от анализа\n",
"# Пример: чем больше средний диаметр и скорость, тем выше ущерб. Чем больше расстояние, тем ниже ущерб.\n",
"df[\"impact_damage_index\"] = (\n",
" (df[\"est_diameter_min\"] + df[\"est_diameter_max\"]) / 2 # Средний диаметр\n",
" * df[\"relative_velocity\"] # Скорость\n",
" / df[\"miss_distance\"] # Обратная зависимость от расстояния\n",
")\n",
"\n",
"# Проверка новых данных\n",
"print(df[[\"est_diameter_min\", \"est_diameter_max\", \"relative_velocity\", \"miss_distance\", \"impact_damage_index\"]].head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии "
]
},
{
"cell_type": "code",
"execution_count": 203,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>absolute_magnitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>35538</th>\n",
" <td>3826685</td>\n",
" <td>0.038420</td>\n",
" <td>0.085909</td>\n",
" <td>91103.489666</td>\n",
" <td>6.350550e+07</td>\n",
" <td>24.20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40393</th>\n",
" <td>2277830</td>\n",
" <td>0.192555</td>\n",
" <td>0.430566</td>\n",
" <td>28359.611312</td>\n",
" <td>2.868167e+07</td>\n",
" <td>20.70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>58540</th>\n",
" <td>3638201</td>\n",
" <td>0.004619</td>\n",
" <td>0.010329</td>\n",
" <td>107351.426865</td>\n",
" <td>5.388098e+04</td>\n",
" <td>28.80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61670</th>\n",
" <td>3836282</td>\n",
" <td>0.015295</td>\n",
" <td>0.034201</td>\n",
" <td>21423.536884</td>\n",
" <td>5.103884e+07</td>\n",
" <td>26.20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11435</th>\n",
" <td>3802002</td>\n",
" <td>0.011603</td>\n",
" <td>0.025944</td>\n",
" <td>69856.053840</td>\n",
" <td>7.360836e+07</td>\n",
" <td>26.80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6265</th>\n",
" <td>2530151</td>\n",
" <td>0.211132</td>\n",
" <td>0.472106</td>\n",
" <td>88209.754856</td>\n",
" <td>4.034289e+07</td>\n",
" <td>20.50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54886</th>\n",
" <td>3831736</td>\n",
" <td>0.035039</td>\n",
" <td>0.078350</td>\n",
" <td>58758.452153</td>\n",
" <td>4.389994e+06</td>\n",
" <td>24.40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>76820</th>\n",
" <td>2512234</td>\n",
" <td>0.211132</td>\n",
" <td>0.472106</td>\n",
" <td>52355.509176</td>\n",
" <td>4.380532e+07</td>\n",
" <td>20.50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>54054466</td>\n",
" <td>0.282199</td>\n",
" <td>0.631015</td>\n",
" <td>50527.379563</td>\n",
" <td>5.837007e+07</td>\n",
" <td>19.87</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15795</th>\n",
" <td>3773929</td>\n",
" <td>0.075258</td>\n",
" <td>0.168283</td>\n",
" <td>22527.647871</td>\n",
" <td>2.281469e+07</td>\n",
" <td>22.74</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>72668 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" id est_diameter_min est_diameter_max relative_velocity \\\n",
"35538 3826685 0.038420 0.085909 91103.489666 \n",
"40393 2277830 0.192555 0.430566 28359.611312 \n",
"58540 3638201 0.004619 0.010329 107351.426865 \n",
"61670 3836282 0.015295 0.034201 21423.536884 \n",
"11435 3802002 0.011603 0.025944 69856.053840 \n",
"... ... ... ... ... \n",
"6265 2530151 0.211132 0.472106 88209.754856 \n",
"54886 3831736 0.035039 0.078350 58758.452153 \n",
"76820 2512234 0.211132 0.472106 52355.509176 \n",
"860 54054466 0.282199 0.631015 50527.379563 \n",
"15795 3773929 0.075258 0.168283 22527.647871 \n",
"\n",
" miss_distance absolute_magnitude \n",
"35538 6.350550e+07 24.20 \n",
"40393 2.868167e+07 20.70 \n",
"58540 5.388098e+04 28.80 \n",
"61670 5.103884e+07 26.20 \n",
"11435 7.360836e+07 26.80 \n",
"... ... ... \n",
"6265 4.034289e+07 20.50 \n",
"54886 4.389994e+06 24.40 \n",
"76820 4.380532e+07 20.50 \n",
"860 5.837007e+07 19.87 \n",
"15795 2.281469e+07 22.74 \n",
"\n",
"[72668 rows x 6 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>impact_damage_index</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>35538</th>\n",
" <td>0.000089</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40393</th>\n",
" <td>0.000308</td>\n",
" </tr>\n",
" <tr>\n",
" <th>58540</th>\n",
" <td>0.014891</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61670</th>\n",
" <td>0.000010</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11435</th>\n",
" <td>0.000018</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6265</th>\n",
" <td>0.000747</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54886</th>\n",
" <td>0.000759</td>\n",
" </tr>\n",
" <tr>\n",
" <th>76820</th>\n",
" <td>0.000408</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>0.000395</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15795</th>\n",
" <td>0.000120</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>72668 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" impact_damage_index\n",
"35538 0.000089\n",
"40393 0.000308\n",
"58540 0.014891\n",
"61670 0.000010\n",
"11435 0.000018\n",
"... ...\n",
"6265 0.000747\n",
"54886 0.000759\n",
"76820 0.000408\n",
"860 0.000395\n",
"15795 0.000120\n",
"\n",
"[72668 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>est_diameter_min</th>\n",
" <th>est_diameter_max</th>\n",
" <th>relative_velocity</th>\n",
" <th>miss_distance</th>\n",
" <th>absolute_magnitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>20406</th>\n",
" <td>3943344</td>\n",
" <td>0.024241</td>\n",
" <td>0.054205</td>\n",
" <td>22148.962596</td>\n",
" <td>5.028574e+07</td>\n",
" <td>25.20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74443</th>\n",
" <td>3879239</td>\n",
" <td>0.012722</td>\n",
" <td>0.028447</td>\n",
" <td>26477.211836</td>\n",
" <td>1.683201e+06</td>\n",
" <td>26.60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74306</th>\n",
" <td>3879244</td>\n",
" <td>0.013322</td>\n",
" <td>0.029788</td>\n",
" <td>33770.201397</td>\n",
" <td>3.943220e+06</td>\n",
" <td>26.50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>45943</th>\n",
" <td>2481965</td>\n",
" <td>0.193444</td>\n",
" <td>0.432554</td>\n",
" <td>43599.575296</td>\n",
" <td>7.346837e+07</td>\n",
" <td>20.69</td>\n",
" </tr>\n",
" <tr>\n",
" <th>62859</th>\n",
" <td>3789471</td>\n",
" <td>0.044112</td>\n",
" <td>0.098637</td>\n",
" <td>36398.080883</td>\n",
" <td>6.352916e+07</td>\n",
" <td>23.90</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>51634</th>\n",
" <td>3694131</td>\n",
" <td>0.008801</td>\n",
" <td>0.019681</td>\n",
" <td>57414.305699</td>\n",
" <td>1.987273e+07</td>\n",
" <td>27.40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>85083</th>\n",
" <td>54235475</td>\n",
" <td>0.024920</td>\n",
" <td>0.055724</td>\n",
" <td>50882.935767</td>\n",
" <td>3.119646e+07</td>\n",
" <td>25.14</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38905</th>\n",
" <td>3775176</td>\n",
" <td>0.008405</td>\n",
" <td>0.018795</td>\n",
" <td>24954.754212</td>\n",
" <td>1.111942e+07</td>\n",
" <td>27.50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16144</th>\n",
" <td>2434734</td>\n",
" <td>0.265800</td>\n",
" <td>0.594347</td>\n",
" <td>57455.404666</td>\n",
" <td>8.501684e+06</td>\n",
" <td>20.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54508</th>\n",
" <td>3170208</td>\n",
" <td>0.023150</td>\n",
" <td>0.051765</td>\n",
" <td>72602.093427</td>\n",
" <td>4.624727e+07</td>\n",
" <td>25.30</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>18168 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" id est_diameter_min est_diameter_max relative_velocity \\\n",
"20406 3943344 0.024241 0.054205 22148.962596 \n",
"74443 3879239 0.012722 0.028447 26477.211836 \n",
"74306 3879244 0.013322 0.029788 33770.201397 \n",
"45943 2481965 0.193444 0.432554 43599.575296 \n",
"62859 3789471 0.044112 0.098637 36398.080883 \n",
"... ... ... ... ... \n",
"51634 3694131 0.008801 0.019681 57414.305699 \n",
"85083 54235475 0.024920 0.055724 50882.935767 \n",
"38905 3775176 0.008405 0.018795 24954.754212 \n",
"16144 2434734 0.265800 0.594347 57455.404666 \n",
"54508 3170208 0.023150 0.051765 72602.093427 \n",
"\n",
" miss_distance absolute_magnitude \n",
"20406 5.028574e+07 25.20 \n",
"74443 1.683201e+06 26.60 \n",
"74306 3.943220e+06 26.50 \n",
"45943 7.346837e+07 20.69 \n",
"62859 6.352916e+07 23.90 \n",
"... ... ... \n",
"51634 1.987273e+07 27.40 \n",
"85083 3.119646e+07 25.14 \n",
"38905 1.111942e+07 27.50 \n",
"16144 8.501684e+06 20.00 \n",
"54508 4.624727e+07 25.30 \n",
"\n",
"[18168 rows x 6 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>impact_damage_index</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>20406</th>\n",
" <td>0.000017</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74443</th>\n",
" <td>0.000324</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74306</th>\n",
" <td>0.000185</td>\n",
" </tr>\n",
" <tr>\n",
" <th>45943</th>\n",
" <td>0.000186</td>\n",
" </tr>\n",
" <tr>\n",
" <th>62859</th>\n",
" <td>0.000041</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>51634</th>\n",
" <td>0.000041</td>\n",
" </tr>\n",
" <tr>\n",
" <th>85083</th>\n",
" <td>0.000066</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38905</th>\n",
" <td>0.000031</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16144</th>\n",
" <td>0.002906</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54508</th>\n",
" <td>0.000059</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>18168 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" impact_damage_index\n",
"20406 0.000017\n",
"74443 0.000324\n",
"74306 0.000185\n",
"45943 0.000186\n",
"62859 0.000041\n",
"... ...\n",
"51634 0.000041\n",
"85083 0.000066\n",
"38905 0.000031\n",
"16144 0.002906\n",
"54508 0.000059\n",
"\n",
"[18168 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" target_colname: str = \"impact_damage_index\",\n",
" frac_train: float = 0.8,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" # Проверка наличия целевого признака\n",
" if target_colname not in df_input.columns:\n",
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
" \n",
" # Разделяем данные на признаки и целевую переменную\n",
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
" y = df_input[[target_colname]] # Целевая переменная\n",
"\n",
" # Удаляем указанные столбцы из X\n",
" columns_to_remove = [\"sentry_object\", \"hazardous\", \"orbiting_body\", \"name\"]\n",
" X = X.drop(columns=columns_to_remove, errors='ignore') # Игнорировать ошибку, если столбцы не найдены\n",
"\n",
" # Разделяем данные на обучающую и тестовую выборки\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" \n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"# Применение функции для разделения данных\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
" df, \n",
" target_colname=\"impact_damage_index\", \n",
" frac_train=0.8, \n",
" random_state=42\n",
")\n",
"\n",
"# Для отображения результатов\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 204,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"random_state = 9\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для регрессии"
]
},
{
"cell_type": "code",
"execution_count": 205,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import math\n",
"from pandas import DataFrame\n",
"from sklearn import metrics\n",
"\n",
"for model_name in models.keys():\n",
" print(f\"Model: {model_name}\")\n",
"\n",
" fitted_model = models[model_name][\"model\"].fit(\n",
" X_train.values, y_train.values.ravel()\n",
" )\n",
" y_train_pred = fitted_model.predict(X_train.values)\n",
" y_test_pred = fitted_model.predict(X_test.values)\n",
" models[model_name][\"fitted\"] = fitted_model\n",
" models[model_name][\"train_preds\"] = y_train_pred\n",
" models[model_name][\"preds\"] = y_test_pred\n",
" models[model_name][\"RMSE_train\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_train, y_train_pred)\n",
" )\n",
" models[model_name][\"RMSE_test\"] = math.sqrt(\n",
" metrics.mean_squared_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"RMAE_test\"] = math.sqrt(\n",
" metrics.mean_absolute_error(y_test, y_test_pred)\n",
" )\n",
" models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод результатов оценки"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_38ff3_row0_col0, #T_38ff3_row0_col1, #T_38ff3_row1_col0, #T_38ff3_row1_col1, #T_38ff3_row2_col0, #T_38ff3_row2_col1, #T_38ff3_row3_col0, #T_38ff3_row3_col1, #T_38ff3_row4_col0, #T_38ff3_row4_col1, #T_38ff3_row5_col0, #T_38ff3_row5_col1, #T_38ff3_row6_col0, #T_38ff3_row6_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_38ff3_row0_col2, #T_38ff3_row1_col2, #T_38ff3_row2_col2, #T_38ff3_row3_col2, #T_38ff3_row4_col2, #T_38ff3_row5_col2, #T_38ff3_row7_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_38ff3_row0_col3, #T_38ff3_row1_col3, #T_38ff3_row2_col3, #T_38ff3_row3_col3, #T_38ff3_row4_col3, #T_38ff3_row5_col3, #T_38ff3_row6_col3, #T_38ff3_row7_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_38ff3_row6_col2 {\n",
" background-color: #5002a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_38ff3_row7_col0, #T_38ff3_row7_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_38ff3\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_38ff3_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_38ff3_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_38ff3_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_38ff3_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_38ff3_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_38ff3_row0_col0\" class=\"data row0 col0\" >0.000409</td>\n",
" <td id=\"T_38ff3_row0_col1\" class=\"data row0 col1\" >0.000711</td>\n",
" <td id=\"T_38ff3_row0_col2\" class=\"data row0 col2\" >0.012593</td>\n",
" <td id=\"T_38ff3_row0_col3\" class=\"data row0 col3\" >0.852564</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_38ff3_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_38ff3_row1_col0\" class=\"data row1 col0\" >0.000511</td>\n",
" <td id=\"T_38ff3_row1_col1\" class=\"data row1 col1\" >0.001031</td>\n",
" <td id=\"T_38ff3_row1_col2\" class=\"data row1 col2\" >0.015170</td>\n",
" <td id=\"T_38ff3_row1_col3\" class=\"data row1 col3\" >0.689858</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_38ff3_level0_row2\" class=\"row_heading level0 row2\" >linear_poly</th>\n",
" <td id=\"T_38ff3_row2_col0\" class=\"data row2 col0\" >0.001217</td>\n",
" <td id=\"T_38ff3_row2_col1\" class=\"data row2 col1\" >0.001476</td>\n",
" <td id=\"T_38ff3_row2_col2\" class=\"data row2 col2\" >0.018001</td>\n",
" <td id=\"T_38ff3_row2_col3\" class=\"data row2 col3\" >0.364795</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_38ff3_level0_row3\" class=\"row_heading level0 row3\" >linear_interact</th>\n",
" <td id=\"T_38ff3_row3_col0\" class=\"data row3 col0\" >0.001263</td>\n",
" <td id=\"T_38ff3_row3_col1\" class=\"data row3 col1\" >0.001500</td>\n",
" <td id=\"T_38ff3_row3_col2\" class=\"data row3 col2\" >0.018235</td>\n",
" <td id=\"T_38ff3_row3_col3\" class=\"data row3 col3\" >0.343354</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_38ff3_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_38ff3_row4_col0\" class=\"data row4 col0\" >0.001206</td>\n",
" <td id=\"T_38ff3_row4_col1\" class=\"data row4 col1\" >0.001611</td>\n",
" <td id=\"T_38ff3_row4_col2\" class=\"data row4 col2\" >0.019245</td>\n",
" <td id=\"T_38ff3_row4_col3\" class=\"data row4 col3\" >0.243014</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_38ff3_level0_row5\" class=\"row_heading level0 row5\" >linear</th>\n",
" <td id=\"T_38ff3_row5_col0\" class=\"data row5 col0\" >0.001382</td>\n",
" <td id=\"T_38ff3_row5_col1\" class=\"data row5 col1\" >0.001629</td>\n",
" <td id=\"T_38ff3_row5_col2\" class=\"data row5 col2\" >0.019724</td>\n",
" <td id=\"T_38ff3_row5_col3\" class=\"data row5 col3\" >0.225851</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_38ff3_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
" <td id=\"T_38ff3_row6_col0\" class=\"data row6 col0\" >0.001610</td>\n",
" <td id=\"T_38ff3_row6_col1\" class=\"data row6 col1\" >0.001852</td>\n",
" <td id=\"T_38ff3_row6_col2\" class=\"data row6 col2\" >0.023283</td>\n",
" <td id=\"T_38ff3_row6_col3\" class=\"data row6 col3\" >-0.000074</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_38ff3_level0_row7\" class=\"row_heading level0 row7\" >ridge</th>\n",
" <td id=\"T_38ff3_row7_col0\" class=\"data row7 col0\" >2.251826</td>\n",
" <td id=\"T_38ff3_row7_col1\" class=\"data row7 col1\" >2.248301</td>\n",
" <td id=\"T_38ff3_row7_col2\" class=\"data row7 col2\" >1.349327</td>\n",
" <td id=\"T_38ff3_row7_col3\" class=\"data row7 col3\" >-1474534.430780</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1b4c88a89e0>"
]
},
"execution_count": 206,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вывод реального и \"спрогнозированного\" результата для обучающей и тестовой выборок\n",
"\n",
"Получение лучшей модели"
]
},
{
"cell_type": "code",
"execution_count": 207,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'random_forest'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 209,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие параметры: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 50}\n",
"Лучший результат (MSE): 5.418559949534169e-07\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.ensemble import RandomForestRegressor # Используем регрессор\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"\n",
"df.dropna(inplace=True) \n",
"# Предикторы и целевая переменная\n",
"X = df[[\"est_diameter_min\", \"est_diameter_max\", \"relative_velocity\", \"miss_distance\", \"absolute_magnitude\"]]\n",
"y = df['impact_damage_index'] # Целевая переменная для регрессии\n",
"\n",
"\n",
"model = RandomForestRegressor() \n",
"\n",
"param_grid = {\n",
" 'n_estimators': [50, 100], \n",
" 'max_depth': [10, 20], \n",
" 'min_samples_split': [5, 10] \n",
"}\n",
"\n",
"# 3. Подбор гиперпараметров с помощью Grid Search\n",
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_train, y_train)\n",
"\n",
"# 4. Результаты подбора гиперпараметров\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами и сравнение новых и старых данных"
]
},
{
"cell_type": "code",
"execution_count": 210,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 8 candidates, totalling 24 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Старые параметры: {'max_depth': 20, 'min_samples_split': 5, 'n_estimators': 50}\n",
"Лучший результат (MSE) на старых параметрах: 5.299415148966497e-07\n",
"\n",
"Новые параметры: {'max_depth': 20, 'min_samples_split': 10, 'n_estimators': 100}\n",
"Лучший результат (MSE) на новых параметрах: 5.355742455463778e-07\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 4.772832137780905e-07\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 0.0006908568692414446\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"old_param_grid = {\n",
" 'n_estimators': [50, 100], # Количество деревьев\n",
" 'max_depth': [ 10, 20], # Максимальная глубина дерева\n",
" 'min_samples_split': [5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=old_param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"old_grid_search.fit(X_train, y_train)\n",
"\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
"\n",
"new_param_grid = {\n",
" 'n_estimators': [100],\n",
" 'max_depth': [20],\n",
" 'min_samples_split': [10]\n",
"}\n",
"\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"new_grid_search.fit(X_train, y_train)\n",
"\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
"\n",
"model_best = RandomForestRegressor(**new_best_params)\n",
"model_best.fit(X_train, y_train)\n",
"\n",
"model_oldbest = RandomForestRegressor(**old_best_params)\n",
"model_oldbest.fit(X_train, y_train)\n",
"\n",
"y_pred = model_best.predict(X_test)\n",
"y_oldpred = model_oldbest.predict(X_test)\n",
"\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nНовые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Попробуем визуализировать"
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 5))\n",
"plt.scatter(range(len(y_test)), y_test, label=\"Актуалочка\", color=\"black\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_pred, label=\"Предсказанные(новые параметры)\", color=\"blue\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_oldpred, label=\"Предсказанные(старые параметры)\", color=\"red\", alpha=0.5)\n",
"plt.xlabel(\"Выборка\")\n",
"plt.ylabel(\"Значения\")\n",
"plt.legend()\n",
"plt.title(\"Актуалочка vs Предсказанных значений (Новые and Старые Параметры)\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}