{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Начало 4-й лабораторной\n", "#### Ближайшие объекты к Земле" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Index(['id', 'name', 'est_diameter_min', 'est_diameter_max',\n", " 'relative_velocity', 'miss_distance', 'orbiting_body', 'sentry_object',\n", " 'absolute_magnitude', 'hazardous'],\n", " dtype='object')\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idnameest_diameter_minest_diameter_maxrelative_velocitymiss_distanceorbiting_bodysentry_objectabsolute_magnitudehazardous
02162635162635 (2000 SS164)1.1982712.67941513569.2492245.483974e+07EarthFalse16.73False
12277475277475 (2005 WK4)0.2658000.59434773588.7266636.143813e+07EarthFalse20.00True
22512244512244 (2015 YE18)0.7220301.614507114258.6921294.979872e+07EarthFalse17.83False
33596030(2012 BV13)0.0965060.21579424764.3031382.543497e+07EarthFalse22.20False
43667127(2014 GE35)0.2550090.57021742737.7337654.627557e+07EarthFalse20.09True
.................................
908313763337(2016 VX1)0.0265800.05943552078.8866921.230039e+07EarthFalse25.00False
908323837603(2019 AD3)0.0167710.03750146114.6050735.432121e+07EarthFalse26.00False
9083354017201(2020 JP3)0.0319560.0714567566.8077322.840077e+07EarthFalse24.60False
9083454115824(2021 CN5)0.0073210.01637069199.1544846.869206e+07EarthFalse27.80False
9083554205447(2021 TW7)0.0398620.08913327024.4555535.977213e+07EarthFalse24.12False
\n", "

90836 rows × 10 columns

\n", "
" ], "text/plain": [ " id name est_diameter_min est_diameter_max \\\n", "0 2162635 162635 (2000 SS164) 1.198271 2.679415 \n", "1 2277475 277475 (2005 WK4) 0.265800 0.594347 \n", "2 2512244 512244 (2015 YE18) 0.722030 1.614507 \n", "3 3596030 (2012 BV13) 0.096506 0.215794 \n", "4 3667127 (2014 GE35) 0.255009 0.570217 \n", "... ... ... ... ... \n", "90831 3763337 (2016 VX1) 0.026580 0.059435 \n", "90832 3837603 (2019 AD3) 0.016771 0.037501 \n", "90833 54017201 (2020 JP3) 0.031956 0.071456 \n", "90834 54115824 (2021 CN5) 0.007321 0.016370 \n", "90835 54205447 (2021 TW7) 0.039862 0.089133 \n", "\n", " relative_velocity miss_distance orbiting_body sentry_object \\\n", "0 13569.249224 5.483974e+07 Earth False \n", "1 73588.726663 6.143813e+07 Earth False \n", "2 114258.692129 4.979872e+07 Earth False \n", "3 24764.303138 2.543497e+07 Earth False \n", "4 42737.733765 4.627557e+07 Earth False \n", "... ... ... ... ... \n", "90831 52078.886692 1.230039e+07 Earth False \n", "90832 46114.605073 5.432121e+07 Earth False \n", "90833 7566.807732 2.840077e+07 Earth False \n", "90834 69199.154484 6.869206e+07 Earth False \n", "90835 27024.455553 5.977213e+07 Earth False \n", "\n", " absolute_magnitude hazardous \n", "0 16.73 False \n", "1 20.00 True \n", "2 17.83 False \n", "3 22.20 False \n", "4 20.09 True \n", "... ... ... \n", "90831 25.00 False \n", "90832 26.00 False \n", "90833 24.60 False \n", "90834 27.80 False \n", "90835 24.12 False \n", "\n", "[90836 rows x 10 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from sklearn.model_selection import train_test_split\n", "from sklearn import set_config\n", "\n", "set_config(transform_output=\"pandas\")\n", "df = pd.read_csv(\".//static//csv//neo.csv\")\n", "print(df.columns)\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Бизнес-цели:\n", "\n", "1. Идентификация потенциально опасных объектов\n", "\n", "Описание: классифицировать астероиды как потенциально опасные или безопасные (используя целевой признак \"hazardous\"). Эта задача актуальна для оценки рисков и подготовки соответствующих действий по защите Земли.\n", "\n", "2. Прогнозирование минимального расстояния до Земли\n", "\n", "Описание: предсказать минимальное расстояние до Земли для новых объектов на основе характеристик астероида (скорости, размера и других параметров). Это позволит планировать исследования и наблюдения в зависимости от опасности. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Определение достижимого уровня качества модели для первой задачи " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n", "\n", "Целевой признак -- hazardous" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'X_train'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idnameest_diameter_minest_diameter_maxrelative_velocitymiss_distanceorbiting_bodysentry_objectabsolute_magnitudehazardous
26393634614(2013 GT66)0.0242410.05420543303.9990944.814117e+07EarthFalse25.20False
2913854143560(2021 JU1)0.0302380.06761521770.7902115.646643e+07EarthFalse24.72False
369273836085(2018 VQ3)0.2016300.450858109358.1230296.435051e+07EarthFalse20.60False
618553769804(2017 DJ34)0.1601600.35812978494.6097565.595780e+07EarthFalse21.10False
159163824978(2018 KS)0.0069910.01563319077.7494863.834648e+07EarthFalse27.90False
.................................
294913827304(2018 RR1)0.0026580.00594319826.8958803.852881e+07EarthFalse30.00False
183733735468(2015 WY1)0.1034080.23122882856.5449267.314334e+07EarthFalse22.05False
250313802041(2018 FE3)0.0096510.02157934243.7742014.257719e+07EarthFalse27.20False
354563430406(2008 TR10)0.2210830.49435619557.2897832.152970e+07EarthFalse20.40False
143053285300(2005 OG3)0.2982330.66686820309.4047061.770015e+07EarthFalse19.75False
\n", "

72668 rows × 10 columns

\n", "
" ], "text/plain": [ " id name est_diameter_min est_diameter_max \\\n", "2639 3634614 (2013 GT66) 0.024241 0.054205 \n", "29138 54143560 (2021 JU1) 0.030238 0.067615 \n", "36927 3836085 (2018 VQ3) 0.201630 0.450858 \n", "61855 3769804 (2017 DJ34) 0.160160 0.358129 \n", "15916 3824978 (2018 KS) 0.006991 0.015633 \n", "... ... ... ... ... \n", "29491 3827304 (2018 RR1) 0.002658 0.005943 \n", "18373 3735468 (2015 WY1) 0.103408 0.231228 \n", "25031 3802041 (2018 FE3) 0.009651 0.021579 \n", "35456 3430406 (2008 TR10) 0.221083 0.494356 \n", "14305 3285300 (2005 OG3) 0.298233 0.666868 \n", "\n", " relative_velocity miss_distance orbiting_body sentry_object \\\n", "2639 43303.999094 4.814117e+07 Earth False \n", "29138 21770.790211 5.646643e+07 Earth False \n", "36927 109358.123029 6.435051e+07 Earth False \n", "61855 78494.609756 5.595780e+07 Earth False \n", "15916 19077.749486 3.834648e+07 Earth False \n", "... ... ... ... ... \n", "29491 19826.895880 3.852881e+07 Earth False \n", "18373 82856.544926 7.314334e+07 Earth False \n", "25031 34243.774201 4.257719e+07 Earth False \n", "35456 19557.289783 2.152970e+07 Earth False \n", "14305 20309.404706 1.770015e+07 Earth False \n", "\n", " absolute_magnitude hazardous \n", "2639 25.20 False \n", "29138 24.72 False \n", "36927 20.60 False \n", "61855 21.10 False \n", "15916 27.90 False \n", "... ... ... \n", "29491 30.00 False \n", "18373 22.05 False \n", "25031 27.20 False \n", "35456 20.40 False \n", "14305 19.75 False \n", "\n", "[72668 rows x 10 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'y_train'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
hazardous
2639False
29138False
36927False
61855False
15916False
......
29491False
18373False
25031False
35456False
14305False
\n", "

72668 rows × 1 columns

\n", "
" ], "text/plain": [ " hazardous\n", "2639 False\n", "29138 False\n", "36927 False\n", "61855 False\n", "15916 False\n", "... ...\n", "29491 False\n", "18373 False\n", "25031 False\n", "35456 False\n", "14305 False\n", "\n", "[72668 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'X_test'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idnameest_diameter_minest_diameter_maxrelative_velocitymiss_distanceorbiting_bodysentry_objectabsolute_magnitudehazardous
90402474532474532 (2003 VG1)0.4726671.05691521779.2371373.443050e+07EarthFalse18.75False
673053774018(2017 HF1)0.0840530.18794953291.0162266.862591e+07EarthFalse22.50False
7774154269585(2022 GQ2)0.0182200.04074243089.0464332.592726e+07EarthFalse25.82False
8152054097970(2020 XS)0.1529520.34201193246.4555994.709054e+07EarthFalse21.20False
5083730802(2015 TT238)0.0319560.07145637708.2585444.232149e+07EarthFalse24.60False
.................................
282613532365(2010 MH1)0.1394940.31191837604.9802387.369507e+07EarthFalse21.40False
115954073345(2020 UE)0.0207280.04634936720.0777283.366114e+07EarthFalse25.54False
480953836195(2018 VT7)0.0069910.0156337616.4965356.376350e+06EarthFalse27.90False
902343752902(2016 JG12)0.0840530.18794921894.5546925.736984e+07EarthFalse22.50False
120133445077(2009 BM58)0.0384200.08590949828.6116094.305599e+07EarthFalse24.20False
\n", "

18168 rows × 10 columns

\n", "
" ], "text/plain": [ " id name est_diameter_min est_diameter_max \\\n", "9040 2474532 474532 (2003 VG1) 0.472667 1.056915 \n", "67305 3774018 (2017 HF1) 0.084053 0.187949 \n", "77741 54269585 (2022 GQ2) 0.018220 0.040742 \n", "81520 54097970 (2020 XS) 0.152952 0.342011 \n", "508 3730802 (2015 TT238) 0.031956 0.071456 \n", "... ... ... ... ... \n", "28261 3532365 (2010 MH1) 0.139494 0.311918 \n", "1159 54073345 (2020 UE) 0.020728 0.046349 \n", "48095 3836195 (2018 VT7) 0.006991 0.015633 \n", "90234 3752902 (2016 JG12) 0.084053 0.187949 \n", "12013 3445077 (2009 BM58) 0.038420 0.085909 \n", "\n", " relative_velocity miss_distance orbiting_body sentry_object \\\n", "9040 21779.237137 3.443050e+07 Earth False \n", "67305 53291.016226 6.862591e+07 Earth False \n", "77741 43089.046433 2.592726e+07 Earth False \n", "81520 93246.455599 4.709054e+07 Earth False \n", "508 37708.258544 4.232149e+07 Earth False \n", "... ... ... ... ... \n", "28261 37604.980238 7.369507e+07 Earth False \n", "1159 36720.077728 3.366114e+07 Earth False \n", "48095 7616.496535 6.376350e+06 Earth False \n", "90234 21894.554692 5.736984e+07 Earth False \n", "12013 49828.611609 4.305599e+07 Earth False \n", "\n", " absolute_magnitude hazardous \n", "9040 18.75 False \n", "67305 22.50 False \n", "77741 25.82 False \n", "81520 21.20 False \n", "508 24.60 False \n", "... ... ... \n", "28261 21.40 False \n", "1159 25.54 False \n", "48095 27.90 False \n", "90234 22.50 False \n", "12013 24.20 False \n", "\n", "[18168 rows x 10 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'y_test'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
hazardous
9040False
67305False
77741False
81520False
508False
......
28261False
1159False
48095False
90234False
12013False
\n", "

18168 rows × 1 columns

\n", "
" ], "text/plain": [ " hazardous\n", "9040 False\n", "67305 False\n", "77741 False\n", "81520 False\n", "508 False\n", "... ...\n", "28261 False\n", "1159 False\n", "48095 False\n", "90234 False\n", "12013 False\n", "\n", "[18168 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from typing import Tuple\n", "import pandas as pd\n", "from pandas import DataFrame\n", "from sklearn.model_selection import train_test_split\n", "\n", "# Устанавливаем случайное состояние\n", "random_state = 42\n", "\n", "def split_stratified_into_train_val_test(\n", " df_input,\n", " stratify_colname=\"y\",\n", " frac_train=0.6,\n", " frac_val=0.15,\n", " frac_test=0.25,\n", " random_state=None,\n", ") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n", " \n", " if frac_train + frac_val + frac_test != 1.0:\n", " raise ValueError(\n", " \"fractions %f, %f, %f do not add up to 1.0\"\n", " % (frac_train, frac_val, frac_test)\n", " )\n", " if stratify_colname not in df_input.columns:\n", " raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n", " X = df_input # Contains all columns.\n", " y = df_input[\n", " [stratify_colname]\n", " ] # Dataframe of just the column on which to stratify.\n", " # Split original dataframe into train and temp dataframes.\n", " df_train, df_temp, y_train, y_temp = train_test_split(\n", " X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n", " )\n", " if frac_val <= 0:\n", " assert len(df_input) == len(df_train) + len(df_temp)\n", " return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n", " # Split the temp dataframe into val and test dataframes.\n", " relative_frac_test = frac_test / (frac_val + frac_test)\n", " df_val, df_test, y_val, y_test = train_test_split(\n", " df_temp,\n", " y_temp,\n", " stratify=y_temp,\n", " test_size=relative_frac_test,\n", " random_state=random_state,\n", " )\n", " assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n", " return df_train, df_val, df_test, y_train, y_val, y_test\n", "\n", "X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n", " df, stratify_colname=\"hazardous\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n", ")\n", "\n", "display(\"X_train\", X_train)\n", "display(\"y_train\", y_train)\n", "\n", "display(\"X_test\", X_test)\n", "display(\"y_test\", y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Формирование конвейера для классификации данных\n", "preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n", "\n", "preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n", "\n", "features_preprocessing -- трансформер для предобработки признаков\n", "\n", "features_engineering -- трансформер для конструирования признаков\n", "\n", "drop_columns -- трансформер для удаления колонок\n", "\n", "pipeline_end -- основной конвейер предобработки данных и конструирования признаков" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.base import BaseEstimator, TransformerMixin\n", "from sklearn.compose import ColumnTransformer\n", "from sklearn.discriminant_analysis import StandardScaler\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import OneHotEncoder\n", "\n", "class EarthObjectsFeatures(BaseEstimator, TransformerMixin):\n", " def __init__(self):\n", " pass\n", " def fit(self, X, y=None):\n", " return self\n", " def transform(self, X, y=None):\n", " X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n", " return X\n", " def get_feature_names_out(self, features_in):\n", " return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n", " \n", "\n", "columns_to_drop = [\"name\", \"orbiting_body\"]\n", "num_columns = [\"est_diameter_min\", \"est_diameter_max\",\n", " \"relative_velocity\", \"miss_distance\", \"sentry_object\",\n", " \"absolute_magnitude\", \"hazardous\"]\n", "cat_columns = []\n", "\n", "num_imputer = SimpleImputer(strategy=\"median\")\n", "num_scaler = StandardScaler()\n", "preprocessing_num = Pipeline(\n", " [\n", " (\"imputer\", num_imputer),\n", " (\"scaler\", num_scaler),\n", " ]\n", ")\n", "\n", "cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n", "cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n", "preprocessing_cat = Pipeline(\n", " [\n", " (\"imputer\", cat_imputer),\n", " (\"encoder\", cat_encoder),\n", " ]\n", ")\n", "\n", "features_preprocessing = ColumnTransformer(\n", " verbose_feature_names_out=False,\n", " transformers=[\n", " (\"prepocessing_num\", preprocessing_num, num_columns),\n", " (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n", " ],\n", " remainder=\"passthrough\"\n", ")\n", "\n", "\n", "drop_columns = ColumnTransformer(\n", " verbose_feature_names_out=False,\n", " transformers=[\n", " (\"drop_columns\", \"drop\", columns_to_drop),\n", " ],\n", " remainder=\"passthrough\",\n", ")\n", "\n", "features_postprocessing = ColumnTransformer(\n", " verbose_feature_names_out=False,\n", " transformers=[\n", " (\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n", " ],\n", " remainder=\"passthrough\",\n", ")\n", "\n", "pipeline_end = Pipeline(\n", " [\n", " (\"features_preprocessing\", features_preprocessing),\n", " (\"drop_columns\", drop_columns),\n", " ]\n", ")\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Демонстрация работы конвейера" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
est_diameter_minest_diameter_maxrelative_velocitymiss_distancesentry_objectabsolute_magnitudehazardousid
2639-0.331616-0.331616-0.1881600.4942970.00.577785-0.3283473634614
29138-0.312486-0.312486-1.0407290.8667160.00.412170-0.32834754143560
369270.2342460.2342462.4271341.2193990.0-1.009355-0.3283473836085
618550.1019600.1019601.2051480.8439630.0-0.836840-0.3283473769804
15916-0.386643-0.386643-1.1473550.0561450.01.509367-0.3283473824978
...........................
29491-0.400466-0.400466-1.1176940.0643010.02.233931-0.3283473827304
18373-0.079077-0.0790771.3778511.6127340.0-0.509061-0.3283473735468
25031-0.378159-0.378159-0.5468840.2454000.01.267846-0.3283473802041
354560.2963000.296300-1.128369-0.6961300.0-1.078361-0.3283473430406
143050.5424040.542404-1.098590-0.8674400.0-1.302631-0.3283473285300
\n", "

72668 rows × 8 columns

\n", "
" ], "text/plain": [ " est_diameter_min est_diameter_max relative_velocity miss_distance \\\n", "2639 -0.331616 -0.331616 -0.188160 0.494297 \n", "29138 -0.312486 -0.312486 -1.040729 0.866716 \n", "36927 0.234246 0.234246 2.427134 1.219399 \n", "61855 0.101960 0.101960 1.205148 0.843963 \n", "15916 -0.386643 -0.386643 -1.147355 0.056145 \n", "... ... ... ... ... \n", "29491 -0.400466 -0.400466 -1.117694 0.064301 \n", "18373 -0.079077 -0.079077 1.377851 1.612734 \n", "25031 -0.378159 -0.378159 -0.546884 0.245400 \n", "35456 0.296300 0.296300 -1.128369 -0.696130 \n", "14305 0.542404 0.542404 -1.098590 -0.867440 \n", "\n", " sentry_object absolute_magnitude hazardous id \n", "2639 0.0 0.577785 -0.328347 3634614 \n", "29138 0.0 0.412170 -0.328347 54143560 \n", "36927 0.0 -1.009355 -0.328347 3836085 \n", "61855 0.0 -0.836840 -0.328347 3769804 \n", "15916 0.0 1.509367 -0.328347 3824978 \n", "... ... ... ... ... \n", "29491 0.0 2.233931 -0.328347 3827304 \n", "18373 0.0 -0.509061 -0.328347 3735468 \n", "25031 0.0 1.267846 -0.328347 3802041 \n", "35456 0.0 -1.078361 -0.328347 3430406 \n", "14305 0.0 -1.302631 -0.328347 3285300 \n", "\n", "[72668 rows x 8 columns]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preprocessing_result = pipeline_end.fit_transform(X_train)\n", "preprocessed_df = pd.DataFrame(\n", " preprocessing_result,\n", " columns=pipeline_end.get_feature_names_out(),\n", ")\n", "\n", "preprocessed_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Формирование набора моделей для классификации\n", " logistic -- логистическая регрессия\n", "\n", "ridge -- гребневая регрессия\n", "\n", "decision_tree -- дерево решений\n", "\n", "knn -- k-ближайших соседей\n", "\n", "naive_bayes -- наивный Байесовский классификатор\n", "\n", "gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n", "\n", "random_forest -- метод случайного леса (набор деревьев решений)\n", "\n", "mlp -- многослойный персептрон (нейронная сеть)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n", "\n", "class_models = {\n", " \"logistic\": {\"model\": linear_model.LogisticRegression()},\n", " # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n", " \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n", " \"decision_tree\": {\n", " \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n", " },\n", " \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n", " \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n", " \"gradient_boosting\": {\n", " \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n", " },\n", " \"random_forest\": {\n", " \"model\": ensemble.RandomForestClassifier(\n", " max_depth=11, class_weight=\"balanced\", random_state=random_state\n", " )\n", " },\n", " \"mlp\": {\n", " \"model\": neural_network.MLPClassifier(\n", " hidden_layer_sizes=(7,),\n", " max_iter=500,\n", " early_stopping=True,\n", " random_state=random_state,\n", " )\n", " },\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Обучение моделей на обучающем наборе данных и оценка на тестовом" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: logistic\n", "Model: ridge\n", "Model: decision_tree\n", "Model: knn\n", "Model: naive_bayes\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", "c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: gradient_boosting\n", "Model: random_forest\n", "Model: mlp\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", "c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" ] } ], "source": [ "import numpy as np\n", "from sklearn import metrics\n", "\n", "for model_name in class_models.keys():\n", " print(f\"Model: {model_name}\")\n", " model = class_models[model_name][\"model\"]\n", "\n", " model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n", " model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n", "\n", " y_train_predict = model_pipeline.predict(X_train)\n", " y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n", " y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n", "\n", " class_models[model_name][\"pipeline\"] = model_pipeline\n", " class_models[model_name][\"probs\"] = y_test_probs\n", " class_models[model_name][\"preds\"] = y_test_predict\n", "\n", " class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n", " y_train, y_train_predict\n", " )\n", " class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n", " y_test, y_test_predict\n", " )\n", " class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n", " y_train, y_train_predict\n", " )\n", " class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n", " y_test, y_test_predict\n", " )\n", " class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n", " y_train, y_train_predict\n", " )\n", " class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n", " y_test, y_test_predict\n", " )\n", " class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n", " y_test, y_test_probs\n", " )\n", " class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n", " class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n", " class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n", " y_test, y_test_predict\n", " )\n", " class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n", " y_test, y_test_predict\n", " )\n", " class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n", " y_test, y_test_predict\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Сводная таблица оценок качества для использованных моделей классификации" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import ConfusionMatrixDisplay\n", "import matplotlib.pyplot as plt\n", "\n", "_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n", "for index, key in enumerate(class_models.keys()):\n", " c_matrix = class_models[key][\"Confusion_matrix\"]\n", " disp = ConfusionMatrixDisplay(\n", " confusion_matrix=c_matrix, display_labels=[\"hazardous\", \"safe\"]\n", " ).plot(ax=ax.flat[index])\n", " disp.ax_.set_title(key)\n", "\n", "plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "16400 - количество истинных положительных диагнозов (True Positives), где модель правильно определила объекты как \"hazardous\".\n", "\n", "1768 в некоторых моделях - количество ложных отрицательных диагнозов (False Negatives), где модель неправильно определила объекты, которые на самом деле принадлежат к классу \"hazardous\", но были отнесены к классу \"safe\". \n", "\n", "Исходя из значений True Positives и False Negatives, можно сказать, что модель имеет высокую точность при предсказании класса \"hazardous\". В принципе, уровень ложных отрицательных результатов в некоторых моделях (1768) говорит нам о том, что существует некотрое небольшое количество примеров, которые модель пропускает.\n", "\n", "Точность, полнота, верность (аккуратность), F-мера" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 Precision_trainPrecision_testRecall_trainRecall_testAccuracy_trainAccuracy_testF1_trainF1_test
logistic1.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
decision_tree1.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
random_forest1.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
gradient_boosting1.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
knn0.8845960.8263740.7446270.6380090.9656930.9517280.8085990.720077
naive_bayes0.0000000.0000000.0000000.0000000.9026810.9026860.0000000.000000
mlp0.0000000.0000000.0000000.0000000.9026810.9026860.0000000.000000
ridge0.4157800.4212531.0000001.0000000.8632550.8663030.5873510.592791
\n" ], "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n", " [\n", " \"Precision_train\",\n", " \"Precision_test\",\n", " \"Recall_train\",\n", " \"Recall_test\",\n", " \"Accuracy_train\",\n", " \"Accuracy_test\",\n", " \"F1_train\",\n", " \"F1_test\",\n", " ]\n", "]\n", "class_metrics.sort_values(\n", " by=\"Accuracy_test\", ascending=False\n", ").style.background_gradient(\n", " cmap=\"plasma\",\n", " low=0.3,\n", " high=1,\n", " subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n", ").background_gradient(\n", " cmap=\"viridis\",\n", " low=1,\n", " high=0.3,\n", " subset=[\n", " \"Precision_train\",\n", " \"Precision_test\",\n", " \"Recall_train\",\n", " \"Recall_test\",\n", " ],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Все модели в данной выборке, а именно логистическая регрессия, ридж-регрессия, дерево решений, KNN, наивный байесовский классификатор, градиентный бустинг, случайный лес и многослойный перцептрон (MLP) демонстрируют неплохие значения по всем метрикам на обучающих и тестовых наборах данных.\n", "\n", "Модели Naive Bayes и MLP не так эффективны по сравнению с другими, но в некоторых метриках показывают высокие результаты. \n", "ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 Accuracy_testF1_testROC_AUC_testCohen_kappa_testMCC_test
logistic1.0000001.0000001.0000001.0000001.000000
decision_tree1.0000001.0000001.0000001.0000001.000000
random_forest1.0000001.0000001.0000001.0000001.000000
gradient_boosting1.0000001.0000001.0000001.0000001.000000
ridge0.8663030.5927910.9956750.5281800.599051
knn0.9517280.7200770.9534050.6941410.701100
naive_bayes0.9026860.0000000.7663410.0000000.000000
mlp0.9026860.0000000.5000000.0000000.000000
\n" ], "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n", " [\n", " \"Accuracy_test\",\n", " \"F1_test\",\n", " \"ROC_AUC_test\",\n", " \"Cohen_kappa_test\",\n", " \"MCC_test\",\n", " ]\n", "]\n", "class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n", " cmap=\"plasma\",\n", " low=0.3,\n", " high=1,\n", " subset=[\n", " \"ROC_AUC_test\",\n", " \"MCC_test\",\n", " \"Cohen_kappa_test\",\n", " ],\n", ").background_gradient(\n", " cmap=\"viridis\",\n", " low=1,\n", " high=0.3,\n", " subset=[\n", " \"Accuracy_test\",\n", " \"F1_test\",\n", " ],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Схожий вывод можно сделать и для следующих метрик: Accuracy, F1, ROC AUC, Cohen's Kappa и MCC. Все модели, кроме Naive Bayes и MLP, указывают на хорошо-развитую способность к выделению классов" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'logistic'" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n", "\n", "display(best_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Вывод данных с ошибкой предсказания для оценки" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Error items count: 0'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idPredictednameest_diameter_minest_diameter_maxrelative_velocitymiss_distanceorbiting_bodysentry_objectabsolute_magnitudehazardous
\n", "
" ], "text/plain": [ "Empty DataFrame\n", "Columns: [id, Predicted, name, est_diameter_min, est_diameter_max, relative_velocity, miss_distance, orbiting_body, sentry_object, absolute_magnitude, hazardous]\n", "Index: []" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preprocessing_result = pipeline_end.transform(X_test)\n", "preprocessed_df = pd.DataFrame(\n", " preprocessing_result,\n", " columns=pipeline_end.get_feature_names_out(),\n", ")\n", "\n", "y_pred = class_models[best_model][\"preds\"]\n", "\n", "error_index = y_test[y_test[\"hazardous\"] != y_pred].index.tolist()\n", "display(f\"Error items count: {len(error_index)}\")\n", "\n", "error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n", "error_df = X_test.loc[error_index].copy()\n", "error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n", "error_df.sort_index()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Пример использования обученной модели (конвейера) для предсказания\n" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idnameest_diameter_minest_diameter_maxrelative_velocitymiss_distanceorbiting_bodysentry_objectabsolute_magnitudehazardous
673053774018(2017 HF1)0.0840530.18794953291.01622668625911.198806EarthFalse22.5False
\n", "
" ], "text/plain": [ " id name est_diameter_min est_diameter_max \\\n", "67305 3774018 (2017 HF1) 0.084053 0.187949 \n", "\n", " relative_velocity miss_distance orbiting_body sentry_object \\\n", "67305 53291.016226 68625911.198806 Earth False \n", "\n", " absolute_magnitude hazardous \n", "67305 22.5 False " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
est_diameter_minest_diameter_maxrelative_velocitymiss_distancesentry_objectabsolute_magnitudehazardousid
67305-0.140818-0.1408180.2072581.4106530.0-0.353797-0.3283473774018.0
\n", "
" ], "text/plain": [ " est_diameter_min est_diameter_max relative_velocity miss_distance \\\n", "67305 -0.140818 -0.140818 0.207258 1.410653 \n", "\n", " sentry_object absolute_magnitude hazardous id \n", "67305 0.0 -0.353797 -0.328347 3774018.0 " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'predicted: False (proba: [9.99855425e-01 1.44575476e-04])'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'real: 0'" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = class_models[best_model][\"pipeline\"]\n", "\n", "example_id = 67305\n", "test = pd.DataFrame(X_test.loc[example_id, :]).T\n", "test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n", "display(test)\n", "display(test_preprocessed)\n", "result_proba = model.predict_proba(test)[0]\n", "result = model.predict(test)[0]\n", "real = int(y_test.loc[example_id].values[0])\n", "display(f\"predicted: {result} (proba: {result_proba})\")\n", "display(f\"real: {real}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Подбор гиперпараметров методом поиска по сетке " ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n", " _data = np.array(data, dtype=dtype, copy=copy,\n" ] }, { "data": { "text/plain": [ "{'model__criterion': 'gini',\n", " 'model__max_depth': 5,\n", " 'model__max_features': 'sqrt',\n", " 'model__n_estimators': 50}" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.model_selection import GridSearchCV\n", "\n", "optimized_model_type = \"random_forest\"\n", "\n", "random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n", "\n", "param_grid = {\n", " \"model__n_estimators\": [10, 50, 100],\n", " \"model__max_features\": [\"sqrt\", \"log2\"],\n", " \"model__max_depth\": [5, 7, 10],\n", " \"model__criterion\": [\"gini\", \"entropy\"],\n", "}\n", "\n", "gs_optomizer = GridSearchCV(\n", " estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n", ")\n", "gs_optomizer.fit(X_train, y_train.values.ravel())\n", "gs_optomizer.best_params_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Обучение модели с новыми гиперпараметрами" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.compose import ColumnTransformer\n", "from sklearn.ensemble import RandomForestClassifier\n", "import numpy as np\n", "from sklearn import metrics\n", "import pandas as pd\n", "\n", "\n", "# Определяем числовые признаки\n", "numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n", "\n", "# Установка random_state\n", "random_state = 42\n", "\n", "# Определение трансформера\n", "pipeline_end = ColumnTransformer([\n", " ('numeric', StandardScaler(), numeric_features),\n", " # Добавьте другие трансформеры, если требуется\n", "])\n", "\n", "# Объявление модели\n", "optimized_model = RandomForestClassifier(\n", " random_state=random_state,\n", " criterion=\"gini\",\n", " max_depth=5,\n", " max_features=\"sqrt\",\n", " n_estimators=50,\n", ")\n", "\n", "# Создание пайплайна с корректными шагами\n", "result = {}\n", "\n", "# Обучение модели\n", "result[\"pipeline\"] = Pipeline([\n", " (\"pipeline\", pipeline_end),\n", " (\"model\", optimized_model)\n", "]).fit(X_train, y_train.values.ravel())\n", "\n", "# Прогнозирование и расчет метрик\n", "result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n", "result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n", "result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n", "\n", "# Метрики для оценки модели\n", "result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n", "result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n", "result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n", "result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n", "result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n", "result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n", "result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n", "result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n", "result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n", "result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n", "result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n", "result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Формирование данных для оценки старой и новой версии модели" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n", "optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n", " data=class_models[optimized_model_type]\n", ")\n", "optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n", " data=result\n", ")\n", "optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n", "optimized_metrics = optimized_metrics.set_index(\"Name\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Оценка параметров старой и новой модели" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 Precision_trainPrecision_testRecall_trainRecall_testAccuracy_trainAccuracy_testF1_trainF1_test
Name        
Old1.0000001.0000001.0000001.0000001.0000001.0000001.0000001.000000
New0.8331910.8625000.1384330.1561090.9134560.9154560.2374200.264368
\n" ], "text/plain": [ "" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "optimized_metrics[\n", " [\n", " \"Precision_train\",\n", " \"Precision_test\",\n", " \"Recall_train\",\n", " \"Recall_test\",\n", " \"Accuracy_train\",\n", " \"Accuracy_test\",\n", " \"F1_train\",\n", " \"F1_test\",\n", " ]\n", "].style.background_gradient(\n", " cmap=\"plasma\",\n", " low=0.3,\n", " high=1,\n", " subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n", ").background_gradient(\n", " cmap=\"viridis\",\n", " low=1,\n", " high=0.3,\n", " subset=[\n", " \"Precision_train\",\n", " \"Precision_test\",\n", " \"Recall_train\",\n", " \"Recall_test\",\n", " ],\n", ")" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 Accuracy_testF1_testROC_AUC_testCohen_kappa_testMCC_test
Name     
Old1.0000001.0000001.0000001.0000001.000000
New0.9154560.2643680.9274930.2417510.345694
\n" ], "text/plain": [ "" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "optimized_metrics[\n", " [\n", " \"Accuracy_test\",\n", " \"F1_test\",\n", " \"ROC_AUC_test\",\n", " \"Cohen_kappa_test\",\n", " \"MCC_test\",\n", " ]\n", "].style.background_gradient(\n", " cmap=\"plasma\",\n", " low=0.3,\n", " high=1,\n", " subset=[\n", " \"ROC_AUC_test\",\n", " \"MCC_test\",\n", " \"Cohen_kappa_test\",\n", " ],\n", ").background_gradient(\n", " cmap=\"viridis\",\n", " low=1,\n", " high=0.3,\n", " subset=[\n", " \"Accuracy_test\",\n", " \"F1_test\",\n", " ],\n", ")" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n", ")\n", "\n", "for index in range(0, len(optimized_metrics)):\n", " c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n", " disp = ConfusionMatrixDisplay(\n", " confusion_matrix=c_matrix, display_labels=[\"hazardous\", \"safe\"]\n", " ).plot(ax=ax.flat[index])\n", "\n", "plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "В желтых квадрате мы наблюдаем значение 16400, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"hazardsous\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n", "\n", "В фиолетвом квадрате значение 276 указывает на количество правильно классифицированных объектов, отнесенных к классу \"More\". Это является показателем не такой высокой точности модели в определении объектов данного класса." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Определение достижимого уровня качества модели для второй задачи (задача регрессии)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Index(['id', 'name', 'est_diameter_min', 'est_diameter_max',\n", " 'relative_velocity', 'miss_distance', 'orbiting_body', 'sentry_object',\n", " 'absolute_magnitude', 'hazardous'],\n", " dtype='object')\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idnameest_diameter_minest_diameter_maxrelative_velocitymiss_distanceorbiting_bodysentry_objectabsolute_magnitudehazardous
02162635162635 (2000 SS164)1.1982712.67941513569.2492245.483974e+07EarthFalse16.73False
12277475277475 (2005 WK4)0.2658000.59434773588.7266636.143813e+07EarthFalse20.00True
22512244512244 (2015 YE18)0.7220301.614507114258.6921294.979872e+07EarthFalse17.83False
33596030(2012 BV13)0.0965060.21579424764.3031382.543497e+07EarthFalse22.20False
43667127(2014 GE35)0.2550090.57021742737.7337654.627557e+07EarthFalse20.09True
.................................
908313763337(2016 VX1)0.0265800.05943552078.8866921.230039e+07EarthFalse25.00False
908323837603(2019 AD3)0.0167710.03750146114.6050735.432121e+07EarthFalse26.00False
9083354017201(2020 JP3)0.0319560.0714567566.8077322.840077e+07EarthFalse24.60False
9083454115824(2021 CN5)0.0073210.01637069199.1544846.869206e+07EarthFalse27.80False
9083554205447(2021 TW7)0.0398620.08913327024.4555535.977213e+07EarthFalse24.12False
\n", "

90836 rows × 10 columns

\n", "
" ], "text/plain": [ " id name est_diameter_min est_diameter_max \\\n", "0 2162635 162635 (2000 SS164) 1.198271 2.679415 \n", "1 2277475 277475 (2005 WK4) 0.265800 0.594347 \n", "2 2512244 512244 (2015 YE18) 0.722030 1.614507 \n", "3 3596030 (2012 BV13) 0.096506 0.215794 \n", "4 3667127 (2014 GE35) 0.255009 0.570217 \n", "... ... ... ... ... \n", "90831 3763337 (2016 VX1) 0.026580 0.059435 \n", "90832 3837603 (2019 AD3) 0.016771 0.037501 \n", "90833 54017201 (2020 JP3) 0.031956 0.071456 \n", "90834 54115824 (2021 CN5) 0.007321 0.016370 \n", "90835 54205447 (2021 TW7) 0.039862 0.089133 \n", "\n", " relative_velocity miss_distance orbiting_body sentry_object \\\n", "0 13569.249224 5.483974e+07 Earth False \n", "1 73588.726663 6.143813e+07 Earth False \n", "2 114258.692129 4.979872e+07 Earth False \n", "3 24764.303138 2.543497e+07 Earth False \n", "4 42737.733765 4.627557e+07 Earth False \n", "... ... ... ... ... \n", "90831 52078.886692 1.230039e+07 Earth False \n", "90832 46114.605073 5.432121e+07 Earth False \n", "90833 7566.807732 2.840077e+07 Earth False \n", "90834 69199.154484 6.869206e+07 Earth False \n", "90835 27024.455553 5.977213e+07 Earth False \n", "\n", " absolute_magnitude hazardous \n", "0 16.73 False \n", "1 20.00 True \n", "2 17.83 False \n", "3 22.20 False \n", "4 20.09 True \n", "... ... ... \n", "90831 25.00 False \n", "90832 26.00 False \n", "90833 24.60 False \n", "90834 27.80 False \n", "90835 24.12 False \n", "\n", "[90836 rows x 10 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from sklearn.model_selection import train_test_split\n", "from sklearn import set_config\n", "\n", "random_state=42\n", "set_config(transform_output=\"pandas\")\n", "df = pd.read_csv(\".//static//csv//neo.csv\")\n", "print(df.columns)\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'X_train'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idnameest_diameter_minest_diameter_maxrelative_velocityorbiting_bodysentry_objectabsolute_magnitudehazardous
355383826685(2018 PR10)0.0384200.08590991103.489666EarthFalse24.20False
403932277830277830 (2006 HR29)0.1925550.43056628359.611312EarthFalse20.70False
585403638201(2013 HT25)0.0046190.010329107351.426865EarthFalse28.80False
616703836282(2018 WR)0.0152950.03420121423.536884EarthFalse26.20False
114353802002(2018 FU1)0.0116030.02594469856.053840EarthFalse26.80False
..............................
62652530151530151 (2011 AW55)0.2111320.47210688209.754856EarthFalse20.50False
548863831736(2018 TD5)0.0350390.07835058758.452153EarthFalse24.40False
768202512234512234 (2015 VO66)0.2111320.47210652355.509176EarthFalse20.50True
86054054466(2020 SG1)0.2821990.63101550527.379563EarthFalse19.87False
157953773929(2017 GL7)0.0752580.16828322527.647871EarthFalse22.74False
\n", "

72668 rows × 9 columns

\n", "
" ], "text/plain": [ " id name est_diameter_min est_diameter_max \\\n", "35538 3826685 (2018 PR10) 0.038420 0.085909 \n", "40393 2277830 277830 (2006 HR29) 0.192555 0.430566 \n", "58540 3638201 (2013 HT25) 0.004619 0.010329 \n", "61670 3836282 (2018 WR) 0.015295 0.034201 \n", "11435 3802002 (2018 FU1) 0.011603 0.025944 \n", "... ... ... ... ... \n", "6265 2530151 530151 (2011 AW55) 0.211132 0.472106 \n", "54886 3831736 (2018 TD5) 0.035039 0.078350 \n", "76820 2512234 512234 (2015 VO66) 0.211132 0.472106 \n", "860 54054466 (2020 SG1) 0.282199 0.631015 \n", "15795 3773929 (2017 GL7) 0.075258 0.168283 \n", "\n", " relative_velocity orbiting_body sentry_object absolute_magnitude \\\n", "35538 91103.489666 Earth False 24.20 \n", "40393 28359.611312 Earth False 20.70 \n", "58540 107351.426865 Earth False 28.80 \n", "61670 21423.536884 Earth False 26.20 \n", "11435 69856.053840 Earth False 26.80 \n", "... ... ... ... ... \n", "6265 88209.754856 Earth False 20.50 \n", "54886 58758.452153 Earth False 24.40 \n", "76820 52355.509176 Earth False 20.50 \n", "860 50527.379563 Earth False 19.87 \n", "15795 22527.647871 Earth False 22.74 \n", "\n", " hazardous \n", "35538 False \n", "40393 False \n", "58540 False \n", "61670 False \n", "11435 False \n", "... ... \n", "6265 False \n", "54886 False \n", "76820 True \n", "860 False \n", "15795 False \n", "\n", "[72668 rows x 9 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'y_train'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
miss_distance
355386.350550e+07
403932.868167e+07
585405.388098e+04
616705.103884e+07
114357.360836e+07
......
62654.034289e+07
548864.389994e+06
768204.380532e+07
8605.837007e+07
157952.281469e+07
\n", "

72668 rows × 1 columns

\n", "
" ], "text/plain": [ " miss_distance\n", "35538 6.350550e+07\n", "40393 2.868167e+07\n", "58540 5.388098e+04\n", "61670 5.103884e+07\n", "11435 7.360836e+07\n", "... ...\n", "6265 4.034289e+07\n", "54886 4.389994e+06\n", "76820 4.380532e+07\n", "860 5.837007e+07\n", "15795 2.281469e+07\n", "\n", "[72668 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'X_test'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idnameest_diameter_minest_diameter_maxrelative_velocityorbiting_bodysentry_objectabsolute_magnitudehazardous
204063943344(2019 YT1)0.0242410.05420522148.962596EarthFalse25.20False
744433879239(2019 US)0.0127220.02844726477.211836EarthFalse26.60False
743063879244(2019 UU)0.0133220.02978833770.201397EarthFalse26.50False
459432481965481965 (2009 EB1)0.1934440.43255443599.575296EarthFalse20.69False
628593789471(2017 WJ1)0.0441120.09863736398.080883EarthFalse23.90False
..............................
516343694131(2014 UF56)0.0088010.01968157414.305699EarthFalse27.40False
8508354235475(2022 AG1)0.0249200.05572450882.935767EarthFalse25.14False
389053775176(2017 LD)0.0084050.01879524954.754212EarthFalse27.50False
161442434734434734 (2006 FX)0.2658000.59434757455.404666EarthFalse20.00True
545083170208(2003 YG136)0.0231500.05176572602.093427EarthFalse25.30False
\n", "

18168 rows × 9 columns

\n", "
" ], "text/plain": [ " id name est_diameter_min est_diameter_max \\\n", "20406 3943344 (2019 YT1) 0.024241 0.054205 \n", "74443 3879239 (2019 US) 0.012722 0.028447 \n", "74306 3879244 (2019 UU) 0.013322 0.029788 \n", "45943 2481965 481965 (2009 EB1) 0.193444 0.432554 \n", "62859 3789471 (2017 WJ1) 0.044112 0.098637 \n", "... ... ... ... ... \n", "51634 3694131 (2014 UF56) 0.008801 0.019681 \n", "85083 54235475 (2022 AG1) 0.024920 0.055724 \n", "38905 3775176 (2017 LD) 0.008405 0.018795 \n", "16144 2434734 434734 (2006 FX) 0.265800 0.594347 \n", "54508 3170208 (2003 YG136) 0.023150 0.051765 \n", "\n", " relative_velocity orbiting_body sentry_object absolute_magnitude \\\n", "20406 22148.962596 Earth False 25.20 \n", "74443 26477.211836 Earth False 26.60 \n", "74306 33770.201397 Earth False 26.50 \n", "45943 43599.575296 Earth False 20.69 \n", "62859 36398.080883 Earth False 23.90 \n", "... ... ... ... ... \n", "51634 57414.305699 Earth False 27.40 \n", "85083 50882.935767 Earth False 25.14 \n", "38905 24954.754212 Earth False 27.50 \n", "16144 57455.404666 Earth False 20.00 \n", "54508 72602.093427 Earth False 25.30 \n", "\n", " hazardous \n", "20406 False \n", "74443 False \n", "74306 False \n", "45943 False \n", "62859 False \n", "... ... \n", "51634 False \n", "85083 False \n", "38905 False \n", "16144 True \n", "54508 False \n", "\n", "[18168 rows x 9 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'y_test'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
miss_distance
204065.028574e+07
744431.683201e+06
743063.943220e+06
459437.346837e+07
628596.352916e+07
......
516341.987273e+07
850833.119646e+07
389051.111942e+07
161448.501684e+06
545084.624727e+07
\n", "

18168 rows × 1 columns

\n", "
" ], "text/plain": [ " miss_distance\n", "20406 5.028574e+07\n", "74443 1.683201e+06\n", "74306 3.943220e+06\n", "45943 7.346837e+07\n", "62859 6.352916e+07\n", "... ...\n", "51634 1.987273e+07\n", "85083 3.119646e+07\n", "38905 1.111942e+07\n", "16144 8.501684e+06\n", "54508 4.624727e+07\n", "\n", "[18168 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from typing import Tuple\n", "import pandas as pd\n", "from pandas import DataFrame\n", "from sklearn.model_selection import train_test_split\n", "\n", "def split_into_train_test(\n", " df_input: DataFrame,\n", " target_colname: str = \"miss_distance\",\n", " frac_train: float = 0.8,\n", " random_state: int = None,\n", ") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n", " \n", " if not (0 < frac_train < 1):\n", " raise ValueError(\"Fraction must be between 0 and 1.\")\n", " \n", " # Проверка наличия целевого признака\n", " if target_colname not in df_input.columns:\n", " raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n", " \n", " # Разделяем данные на признаки и целевую переменную\n", " X = df_input.drop(columns=[target_colname]) # Признаки\n", " y = df_input[[target_colname]] # Целевая переменная\n", "\n", " # Разделяем данные на обучающую и тестовую выборки\n", " X_train, X_test, y_train, y_test = train_test_split(\n", " X, y,\n", " test_size=(1.0 - frac_train),\n", " random_state=random_state\n", " )\n", " \n", " return X_train, X_test, y_train, y_test\n", "\n", "# Применение функции для разделения данных\n", "X_train, X_test, y_train, y_test = split_into_train_test(\n", " df, \n", " target_colname=\"miss_distance\", \n", " frac_train=0.8, \n", " random_state=42\n", ")\n", "\n", "# Для отображения результатов\n", "display(\"X_train\", X_train)\n", "display(\"y_train\", y_train)\n", "\n", "display(\"X_test\", X_test)\n", "display(\"y_test\", y_test)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Формирование конвейера для решения задачи регрессии" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.base import BaseEstimator, TransformerMixin\n", "from sklearn.compose import ColumnTransformer\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import OneHotEncoder\n", "from sklearn.ensemble import RandomForestRegressor # Пример регрессионной модели\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.pipeline import make_pipeline\n", "\n", "class EarthObjectsFeatures(BaseEstimator, TransformerMixin):\n", " def __init__(self):\n", " pass\n", " \n", " def fit(self, X, y=None):\n", " return self\n", "\n", " def transform(self, X, y=None):\n", " X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n", " return X\n", "\n", " def get_feature_names_out(self, features_in):\n", " return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n", "\n", "# Указываем столбцы, которые нужно удалить и обрабатывать\n", "columns_to_drop = [\"name\", \"orbiting_body\"]\n", "num_columns = [\"est_diameter_min\", \"est_diameter_max\",\n", " \"relative_velocity\", \"sentry_object\",\n", " \"absolute_magnitude\", \"hazardous\"]\n", "cat_columns = [] \n", "\n", "# Определяем предобработку для численных данных\n", "num_imputer = SimpleImputer(strategy=\"median\")\n", "num_scaler = StandardScaler()\n", "preprocessing_num = Pipeline(\n", " [\n", " (\"imputer\", num_imputer),\n", " (\"scaler\", num_scaler),\n", " ]\n", ")\n", "\n", "# Определяем предобработку для категориальных данных\n", "cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n", "cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n", "preprocessing_cat = Pipeline(\n", " [\n", " (\"imputer\", cat_imputer),\n", " (\"encoder\", cat_encoder),\n", " ]\n", ")\n", "\n", "# Подготовка признаков с использованием ColumnTransformer\n", "features_preprocessing = ColumnTransformer(\n", " verbose_feature_names_out=False,\n", " transformers=[\n", " (\"preprocessing_num\", preprocessing_num, num_columns),\n", " (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n", " ],\n", " remainder=\"passthrough\"\n", ")\n", "\n", "# Удаление нежелательных столбцов\n", "drop_columns = ColumnTransformer(\n", " verbose_feature_names_out=False,\n", " transformers=[\n", " (\"drop_columns\", \"drop\", columns_to_drop),\n", " ],\n", " remainder=\"passthrough\",\n", ")\n", "\n", "# Постобработка признаков\n", "features_postprocessing = ColumnTransformer(\n", " verbose_feature_names_out=False,\n", " transformers=[\n", " (\"preprocessing_cat\", preprocessing_cat, [\"Cabin_type\"]), \n", " ],\n", " remainder=\"passthrough\",\n", ")\n", "\n", "# Создание окончательного конвейера\n", "pipeline = Pipeline(\n", " [\n", " (\"features_preprocessing\", features_preprocessing),\n", " (\"drop_columns\", drop_columns),\n", " (\"model\", RandomForestRegressor()) # Выбор модели для обучения\n", " ]\n", ")\n", "\n", "# Использование конвейера\n", "def train_pipeline(X, y):\n", " pipeline.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Формирование набора моделей для регрессии" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import PolynomialFeatures\n", "from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n", "\n", "random_state = 9\n", "\n", "models = {\n", " \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n", " \"linear_poly\": {\n", " \"model\": make_pipeline(\n", " PolynomialFeatures(degree=2),\n", " linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n", " )\n", " },\n", " \"linear_interact\": {\n", " \"model\": make_pipeline(\n", " PolynomialFeatures(interaction_only=True),\n", " linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n", " )\n", " },\n", " \"ridge\": {\"model\": linear_model.RidgeCV()},\n", " \"decision_tree\": {\n", " \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n", " },\n", " \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n", " \"random_forest\": {\n", " \"model\": ensemble.RandomForestRegressor(\n", " max_depth=7, random_state=random_state, n_jobs=-1\n", " )\n", " },\n", " \"mlp\": {\n", " \"model\": neural_network.MLPRegressor(\n", " activation=\"tanh\",\n", " hidden_layer_sizes=(3,),\n", " max_iter=500,\n", " early_stopping=True,\n", " random_state=random_state,\n", " )\n", " },\n", "}" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: linear\n" ] }, { "ename": "ValueError", "evalue": "could not convert string to float: '(2018 PR10)'", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[9], line 8\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m model_name \u001b[38;5;129;01min\u001b[39;00m models\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[0;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 8\u001b[0m fitted_model \u001b[38;5;241m=\u001b[39m \u001b[43mmodels\u001b[49m\u001b[43m[\u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mX_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mravel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 11\u001b[0m y_train_pred \u001b[38;5;241m=\u001b[39m fitted_model\u001b[38;5;241m.\u001b[39mpredict(X_train\u001b[38;5;241m.\u001b[39mvalues)\n\u001b[0;32m 12\u001b[0m y_test_pred \u001b[38;5;241m=\u001b[39m fitted_model\u001b[38;5;241m.\u001b[39mpredict(X_test\u001b[38;5;241m.\u001b[39mvalues)\n", "File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473\u001b[0m, in \u001b[0;36m_fit_context..decorator..wrapper\u001b[1;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1466\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[0;32m 1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 1469\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 1470\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 1471\u001b[0m )\n\u001b[0;32m 1472\u001b[0m ):\n\u001b[1;32m-> 1473\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\linear_model\\_base.py:609\u001b[0m, in \u001b[0;36mLinearRegression.fit\u001b[1;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[0;32m 605\u001b[0m n_jobs_ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_jobs\n\u001b[0;32m 607\u001b[0m accept_sparse \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpositive \u001b[38;5;28;01melse\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcsr\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcsc\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcoo\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m--> 609\u001b[0m X, y \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_data\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 610\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 611\u001b[0m \u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 612\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 613\u001b[0m \u001b[43m \u001b[49m\u001b[43my_numeric\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 614\u001b[0m \u001b[43m \u001b[49m\u001b[43mmulti_output\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 615\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_writeable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 616\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 618\u001b[0m has_sw \u001b[38;5;241m=\u001b[39m sample_weight \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 619\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_sw:\n", "File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:650\u001b[0m, in \u001b[0;36mBaseEstimator._validate_data\u001b[1;34m(self, X, y, reset, validate_separately, cast_to_ndarray, **check_params)\u001b[0m\n\u001b[0;32m 648\u001b[0m y \u001b[38;5;241m=\u001b[39m check_array(y, input_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mcheck_y_params)\n\u001b[0;32m 649\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 650\u001b[0m X, y \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_X_y\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcheck_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 651\u001b[0m out \u001b[38;5;241m=\u001b[39m X, y\n\u001b[0;32m 653\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m check_params\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mensure_2d\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m):\n", "File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1301\u001b[0m, in \u001b[0;36mcheck_X_y\u001b[1;34m(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, estimator)\u001b[0m\n\u001b[0;32m 1296\u001b[0m estimator_name \u001b[38;5;241m=\u001b[39m _check_estimator_name(estimator)\n\u001b[0;32m 1297\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 1298\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mestimator_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m requires y to be passed, but the target y is None\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1299\u001b[0m )\n\u001b[1;32m-> 1301\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_array\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1302\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1303\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1304\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_large_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_large_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1305\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1306\u001b[0m \u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1307\u001b[0m \u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1308\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_writeable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_writeable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1309\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_all_finite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_all_finite\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1310\u001b[0m \u001b[43m \u001b[49m\u001b[43mensure_2d\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_2d\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1311\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_nd\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_nd\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1312\u001b[0m \u001b[43m \u001b[49m\u001b[43mensure_min_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_min_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1313\u001b[0m \u001b[43m \u001b[49m\u001b[43mensure_min_features\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_min_features\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1314\u001b[0m \u001b[43m \u001b[49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1315\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mX\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1316\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1318\u001b[0m y \u001b[38;5;241m=\u001b[39m _check_y(y, multi_output\u001b[38;5;241m=\u001b[39mmulti_output, y_numeric\u001b[38;5;241m=\u001b[39my_numeric, estimator\u001b[38;5;241m=\u001b[39mestimator)\n\u001b[0;32m 1320\u001b[0m check_consistent_length(X, y)\n", "File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1012\u001b[0m, in \u001b[0;36mcheck_array\u001b[1;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)\u001b[0m\n\u001b[0;32m 1010\u001b[0m array \u001b[38;5;241m=\u001b[39m xp\u001b[38;5;241m.\u001b[39mastype(array, dtype, copy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 1011\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1012\u001b[0m array \u001b[38;5;241m=\u001b[39m \u001b[43m_asarray_with_order\u001b[49m\u001b[43m(\u001b[49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mxp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mxp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1013\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ComplexWarning \u001b[38;5;28;01mas\u001b[39;00m complex_warning:\n\u001b[0;32m 1014\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 1015\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mComplex data not supported\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(array)\n\u001b[0;32m 1016\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mcomplex_warning\u001b[39;00m\n", "File \u001b[1;32mc:\\Users\\Admin\\StudioProjects\\AIM-PIbd-31-Alekseev-I-S\\aimenv\\Lib\\site-packages\\sklearn\\utils\\_array_api.py:745\u001b[0m, in \u001b[0;36m_asarray_with_order\u001b[1;34m(array, dtype, order, copy, xp, device)\u001b[0m\n\u001b[0;32m 743\u001b[0m array \u001b[38;5;241m=\u001b[39m numpy\u001b[38;5;241m.\u001b[39marray(array, order\u001b[38;5;241m=\u001b[39morder, dtype\u001b[38;5;241m=\u001b[39mdtype)\n\u001b[0;32m 744\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 745\u001b[0m array \u001b[38;5;241m=\u001b[39m \u001b[43mnumpy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 747\u001b[0m \u001b[38;5;66;03m# At this point array is a NumPy ndarray. We convert it to an array\u001b[39;00m\n\u001b[0;32m 748\u001b[0m \u001b[38;5;66;03m# container that is consistent with the input's namespace.\u001b[39;00m\n\u001b[0;32m 749\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m xp\u001b[38;5;241m.\u001b[39masarray(array)\n", "\u001b[1;31mValueError\u001b[0m: could not convert string to float: '(2018 PR10)'" ] } ], "source": [ "import math\n", "from pandas import DataFrame\n", "from sklearn import metrics\n", "\n", "for model_name in models.keys():\n", " print(f\"Model: {model_name}\")\n", "\n", " fitted_model = models[model_name][\"model\"].fit(\n", " X_train.values, y_train.values.ravel()\n", " )\n", " y_train_pred = fitted_model.predict(X_train.values)\n", " y_test_pred = fitted_model.predict(X_test.values)\n", " models[model_name][\"fitted\"] = fitted_model\n", " models[model_name][\"train_preds\"] = y_train_pred\n", " models[model_name][\"preds\"] = y_test_pred\n", " models[model_name][\"RMSE_train\"] = math.sqrt(\n", " metrics.mean_squared_error(y_train, y_train_pred)\n", " )\n", " models[model_name][\"RMSE_test\"] = math.sqrt(\n", " metrics.mean_squared_error(y_test, y_test_pred)\n", " )\n", " models[model_name][\"RMAE_test\"] = math.sqrt(\n", " metrics.mean_absolute_error(y_test, y_test_pred)\n", " )\n", " models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)" ] } ], "metadata": { "kernelspec": { "display_name": "aimenv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.6" } }, "nbformat": 4, "nbformat_minor": 2 }