2024-12-08 22:49:06 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Лабораторная работа №4\n",
"\n",
2024-12-13 23:27:39 +04:00
"*Вариант задания:* 24 - Н а б о р данных \"Наблюдения НЛО в США\"\n",
"city_latitude"
2024-12-08 22:49:06 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Выбор бизнес-целей \n",
"Для датасета недвижимости предлагаются две бизнес-цели:\n",
"\n",
"### Задача классификации:\n",
2024-12-13 23:27:39 +04:00
"Предсказание вероятности появления НЛО в зависимости от региона\n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
"Описание цели: Классифицировать широты на категории (например, высокая, средняя или низкая вероятность наблюдений НЛО).\n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
"Применение: Помощь исследовательским группам в приоритизации областей для исследований.\n",
2024-12-08 22:49:06 +04:00
"\n",
"### Задача регрессии:\n",
2024-12-13 23:27:39 +04:00
"Описание цели: Предсказать, сколько наблюдений НЛО можно ожидать в определенной широтной зоне(city_latitude).\n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
"Применение: Оптимизация ресурсов для мониторинга"
2024-12-08 22:49:06 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Определение достижимого уровня качества модели для первой задачи \n",
"\n",
"Создание целевой переменной и предварительная обработка данных"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 3,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['summary', 'city', 'state', 'date_time', 'shape', 'duration', 'stats',\n",
" 'report_link', 'text', 'posted', 'city_latitude', 'city_longitude'],\n",
" dtype='object')\n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn import linear_model, tree, neighbors, naive_bayes, ensemble, neural_network\n",
"from sklearn import metrics\n",
"import numpy as np\n",
"import warnings\n",
"#warnings.filterwarnings(\"ignore\", state=UserWarning)\n",
2024-12-13 23:27:39 +04:00
"df = pd.read_csv(\"../../datasets/nuforc_reports.csv\")\n",
"df = df.head(15000)\n",
2024-12-08 22:49:06 +04:00
"print(df.columns)"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 4,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"Среднее значение поля city_latitude: 39.143131819517855\n",
2024-12-08 22:49:06 +04:00
" summary city state \\\n",
"0 Viewed some red lights in the sky appearing to... Visalia CA \n",
"1 Look like 1 or 3 crafts from North traveling s... Cincinnati OH \n",
"2 seen dark rectangle moving slowly thru the sky... Tecopa CA \n",
"3 One red light moving switly west to east, beco... Knoxville TN \n",
"4 Bright, circular Fresnel-lens shaped light sev... Alexandria VA \n",
"\n",
" date_time shape duration \\\n",
"0 2021-12-15T21:45:00 light 2 minutes \n",
"1 2021-12-16T09:45:00 triangle 14 seconds \n",
"2 2021-12-10T00:00:00 rectangle Several minutes \n",
"3 2021-12-10T19:30:00 triangle 20-30 seconds \n",
"4 2021-12-07T08:00:00 circle NaN \n",
"\n",
" stats \\\n",
"0 Occurred : 12/15/2021 21:45 (Entered as : 12/... \n",
"1 Occurred : 12/16/2021 09:45 (Entered as : 12/... \n",
"2 Occurred : 12/10/2021 00:00 (Entered as : 12/... \n",
"3 Occurred : 12/10/2021 19:30 (Entered as : 12/... \n",
"4 Occurred : 12/7/2021 08:00 (Entered as : 12/0... \n",
"\n",
" report_link \\\n",
"0 http://www.nuforc.org/webreports/165/S165881.html \n",
"1 http://www.nuforc.org/webreports/165/S165888.html \n",
"2 http://www.nuforc.org/webreports/165/S165810.html \n",
"3 http://www.nuforc.org/webreports/165/S165825.html \n",
"4 http://www.nuforc.org/webreports/165/S165754.html \n",
"\n",
" text posted \\\n",
"0 Viewed some red lights in the sky appearing to... 2021-12-19T00:00:00 \n",
"1 Look like 1 or 3 crafts from North traveling s... 2021-12-19T00:00:00 \n",
"2 seen dark rectangle moving slowly thru the sky... 2021-12-19T00:00:00 \n",
"3 One red light moving switly west to east, beco... 2021-12-19T00:00:00 \n",
"4 Bright, circular Fresnel-lens shaped light sev... 2021-12-19T00:00:00 \n",
"\n",
" city_latitude city_longitude above_average_city_latitude \n",
"0 36.356650 -119.347937 0 \n",
2024-12-13 23:27:39 +04:00
"1 39.174503 -84.481363 1 \n",
2024-12-08 22:49:06 +04:00
"2 NaN NaN 0 \n",
"3 35.961561 -83.980115 0 \n",
"4 38.798958 -77.095133 0 \n"
]
}
],
"source": [
"# Установим параметры для вывода\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Рассчитываем среднее значение цены\n",
"average_city_latitude = df['city_latitude'].mean()\n",
"print(f\"Среднее значение поля city_latitude: {average_city_latitude}\")\n",
"\n",
"# Создаем новую переменную, указывающую, превышает ли цена среднюю цену\n",
"df['above_average_city_latitude'] = (df['city_latitude'] > average_city_latitude).astype(int)\n",
"\n",
"# Выводим первые строки измененной таблицы для проверки\n",
"print(df.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n",
"\n",
"Целевой признак -- above_average_city_latitude"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 5,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"X_train shape: (12000, 7)\n",
"y_train shape: (12000,)\n",
"X_test shape: (3000, 7)\n",
"y_test shape: (3000,)\n",
2024-12-08 22:49:06 +04:00
"X_train:\n",
2024-12-13 23:27:39 +04:00
" city state date_time shape \\\n",
"2096 Show Low AZ 2020-05-23T03:56:00 NaN \n",
"14791 Langley BC 1997-07-01T10:30:00 fireball \n",
"10478 Lake Whitney TX 1999-08-21T03:00:00 triangle \n",
"11595 Darlington LA 1988-09-15T03:00:00 unknown \n",
"13165 Dallas AR 1999-12-26T19:30:00 fireball \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" text city_latitude \\\n",
"2096 MADAR Node 74 34.234662 \n",
"14791 The object appeared to be a extemly radiant sp... NaN \n",
"10478 As My wife and I were delivering Newspapers th... NaN \n",
"11595 Blade sound passes overhead while photographin... NaN \n",
"13165 At ten degrees above the horizon, above treeli... NaN \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" city_longitude \n",
"2096 -110.075197 \n",
"14791 NaN \n",
"10478 NaN \n",
"11595 NaN \n",
"13165 NaN \n",
2024-12-08 22:49:06 +04:00
"y_train:\n",
2024-12-13 23:27:39 +04:00
" 2096 0\n",
"14791 0\n",
"10478 0\n",
"11595 0\n",
"13165 0\n",
2024-12-08 22:49:06 +04:00
"Name: above_average_city_latitude, dtype: int64\n",
"X_test:\n",
2024-12-13 23:27:39 +04:00
" city state date_time shape \\\n",
"1651 Goa (India) NaN 2020-02-18T19:00:00 flash \n",
"7303 Lansing MI 1997-06-27T02:40:00 light \n",
"14354 San Diego CA 1996-04-05T21:00:00 formation \n",
"14057 Belton MO 1999-11-23T18:10:00 oval \n",
"6461 Phoenix AZ 1995-05-21T19:30:00 NaN \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" text city_latitude \\\n",
"1651 I was there at Baga beach in Goa. That time a... NaN \n",
"7303 Summary : I was outside letting my cat in the ... 42.743800 \n",
"14354 Summary : Two military members sighted two red... 32.787229 \n",
"14057 Three bright lights in the front and sides and... 38.784300 \n",
"6461 Man southbound on I-17 100 miles N of Phoenix ... 33.535381 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" city_longitude \n",
"1651 NaN \n",
"7303 -84.576708 \n",
"14354 -117.140268 \n",
"14057 -94.545700 \n",
"6461 -112.049460 \n",
2024-12-08 22:49:06 +04:00
"y_test:\n",
2024-12-13 23:27:39 +04:00
" 1651 0\n",
"7303 1\n",
"14354 0\n",
"14057 0\n",
"6461 0\n",
2024-12-08 22:49:06 +04:00
"Name: above_average_city_latitude, dtype: int64\n"
]
}
],
"source": [
"# Разделение набора данных на обучающую и тестовую выборки (80/20)\n",
"random_state = 42\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" df.drop(columns=['above_average_city_latitude', 'summary', 'stats', 'report_link', 'posted', \"duration\"]), # Исключаем столбец 'items'\n",
" df['above_average_city_latitude'],\n",
" stratify=df['above_average_city_latitude'],\n",
" test_size=0.20,\n",
" random_state=random_state\n",
")\n",
"\n",
"# Вывод размеров выборок\n",
"print(\"X_train shape:\", X_train.shape)\n",
"print(\"y_train shape:\", y_train.shape)\n",
"print(\"X_test shape:\", X_test.shape)\n",
"print(\"y_test shape:\", y_test.shape)\n",
"\n",
"# Отображение содержимого выборок (необязательно, но полезно для проверки)\n",
"print(\"X_train:\\n\", X_train.head())\n",
"print(\"y_train:\\n\", y_train.head())\n",
"print(\"X_test:\\n\", X_test.head())\n",
"print(\"y_test:\\n\", y_test.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование конвейера для классификации данных\n",
"\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 6,
2024-12-08 22:49:06 +04:00
"metadata": {},
2024-12-13 23:27:39 +04:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
}
],
2024-12-08 22:49:06 +04:00
"source": [
"# Определение столбцов для обработки\n",
"columns_to_drop = [\"date_time\", \"posted\", \"city\", \"state\", \"summary\", \"stats\", \"report_link\", \"duration\", \"text\"] # Столбцы, которые можно удалить\n",
"# ,\n",
"num_columns = [\"city_latitude\", \"city_longitude\"] # Числовые столбцы\n",
"cat_columns = [\"shape\"] # Категориальные столбцы\n",
"\n",
"# Проверка наличия столбцов перед удалением\n",
"columns_to_drop = [col for col in columns_to_drop if col in X_train.columns]\n",
"\n",
"# Препроцессинг числовых столбцов\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Препроцессинг категориальных столбцов\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Объединение препроцессинга\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Удаление ненужных столбцов\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Создание финального пайплайна\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")\n",
"\n",
"# Обучение пайплайна на обучающих данных\n",
"pipeline_end.fit(X_train)\n",
"\n",
"# Преобразование тестовых данных с использованием обученного пайплайна\n",
"X_test_transformed = pipeline_end.transform(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__Де мо нс тр а ция работы ко нве йе р а __"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 7,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
" city_latitude city_longitude shape_changing shape_chevron \\\n",
"2096 -0.967749 -0.790042 0.0 0.0 \n",
"14791 0.080860 0.208199 0.0 0.0 \n",
"10478 0.080860 0.208199 0.0 0.0 \n",
"11595 0.080860 0.208199 0.0 0.0 \n",
"13165 0.080860 0.208199 0.0 0.0 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" shape_cigar shape_circle shape_cone shape_crescent shape_cross \\\n",
"2096 0.0 0.0 0.0 0.0 0.0 \n",
"14791 0.0 0.0 0.0 0.0 0.0 \n",
"10478 0.0 0.0 0.0 0.0 0.0 \n",
"11595 0.0 0.0 0.0 0.0 0.0 \n",
"13165 0.0 0.0 0.0 0.0 0.0 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" shape_cylinder ... shape_light shape_other shape_oval \\\n",
"2096 0.0 ... 0.0 0.0 0.0 \n",
"14791 0.0 ... 0.0 0.0 0.0 \n",
"10478 0.0 ... 0.0 0.0 0.0 \n",
"11595 0.0 ... 0.0 0.0 0.0 \n",
"13165 0.0 ... 0.0 0.0 0.0 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" shape_pyramid shape_rectangle shape_round shape_sphere \\\n",
"2096 0.0 0.0 0.0 0.0 \n",
"14791 0.0 0.0 0.0 0.0 \n",
"10478 0.0 0.0 0.0 0.0 \n",
"11595 0.0 0.0 0.0 0.0 \n",
"13165 0.0 0.0 0.0 0.0 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" shape_teardrop shape_triangle shape_unknown \n",
"2096 0.0 0.0 1.0 \n",
"14791 0.0 0.0 0.0 \n",
"10478 0.0 1.0 0.0 \n",
"11595 0.0 0.0 1.0 \n",
"13165 0.0 0.0 0.0 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
"[5 rows x 28 columns]\n"
2024-12-08 22:49:06 +04:00
]
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"# Вывод первых строк обработанных данных\n",
"print(preprocessed_df.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для классификации\n",
"\n",
"logistic -- логистическая регрессия\n",
"\n",
"ridge -- гребневая регрессия\n",
"\n",
"decision_tree -- дерево решений\n",
"\n",
"knn -- k-ближайших соседей\n",
"\n",
"naive_bayes -- наивный Байесовский классификатор\n",
"\n",
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"\n",
"random_forest -- метод случайного леса (набор деревьев решений)\n",
"\n",
"mlp -- многослойный персептрон (нейронная сеть)"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 8,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [],
"source": [
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=42)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=42\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=42,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение моделей и оценка их качества"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 9,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
2024-12-13 23:27:39 +04:00
"Model: ridge\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-12-08 22:49:06 +04:00
"Model: decision_tree\n",
2024-12-13 23:27:39 +04:00
"Model: knn\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\joblib\\externals\\loky\\backend\\context.py:136: UserWarning: Could not find the number of physical cores for the following reason:\n",
"found 0 physical cores < 1\n",
"Returning the number of logical cores instead. You can silence this warning by setting LOKY_MAX_CPU_COUNT to the number of cores you want to use.\n",
" warnings.warn(\n",
" File \"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\joblib\\externals\\loky\\backend\\context.py\", line 282, in _count_physical_cores\n",
" raise ValueError(f\"found {cpu_count_physical} physical cores < 1\")\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: naive_bayes\n",
"Model: gradient_boosting\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: random_forest\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
2024-12-08 22:49:06 +04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: mlp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
2024-12-08 22:49:06 +04:00
]
}
],
"source": [
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" # Оценка метрик\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 10,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Precision (train): 1.0000\n",
"Precision (test): 1.0000\n",
2024-12-13 23:27:39 +04:00
"Recall (train): 0.8953\n",
"Recall (test): 0.8880\n",
"Accuracy (train): 0.9557\n",
"Accuracy (test): 0.9527\n",
"ROC AUC (test): 0.9787\n",
"F1 (train): 0.9447\n",
"F1 (test): 0.9407\n",
"MCC (test): 0.9059\n",
"Cohen's Kappa (test): 0.9015\n",
2024-12-08 22:49:06 +04:00
"Confusion Matrix:\n",
2024-12-13 23:27:39 +04:00
"[[1732 0]\n",
" [ 142 1126]]\n",
2024-12-08 22:49:06 +04:00
"\n",
"Model: ridge\n",
"Precision (train): 1.0000\n",
"Precision (test): 1.0000\n",
2024-12-13 23:27:39 +04:00
"Recall (train): 0.9077\n",
"Recall (test): 0.9006\n",
"Accuracy (train): 0.9610\n",
"Accuracy (test): 0.9580\n",
"ROC AUC (test): 0.9788\n",
"F1 (train): 0.9516\n",
"F1 (test): 0.9477\n",
"MCC (test): 0.9163\n",
"Cohen's Kappa (test): 0.9128\n",
2024-12-08 22:49:06 +04:00
"Confusion Matrix:\n",
2024-12-13 23:27:39 +04:00
"[[1732 0]\n",
" [ 126 1142]]\n",
2024-12-08 22:49:06 +04:00
"\n",
"Model: decision_tree\n",
"Precision (train): 1.0000\n",
"Precision (test): 1.0000\n",
"Recall (train): 1.0000\n",
2024-12-13 23:27:39 +04:00
"Recall (test): 1.0000\n",
2024-12-08 22:49:06 +04:00
"Accuracy (train): 1.0000\n",
2024-12-13 23:27:39 +04:00
"Accuracy (test): 1.0000\n",
"ROC AUC (test): 1.0000\n",
2024-12-08 22:49:06 +04:00
"F1 (train): 1.0000\n",
2024-12-13 23:27:39 +04:00
"F1 (test): 1.0000\n",
"MCC (test): 1.0000\n",
"Cohen's Kappa (test): 1.0000\n",
2024-12-08 22:49:06 +04:00
"Confusion Matrix:\n",
2024-12-13 23:27:39 +04:00
"[[1732 0]\n",
" [ 0 1268]]\n",
2024-12-08 22:49:06 +04:00
"\n",
"Model: knn\n",
2024-12-13 23:27:39 +04:00
"Precision (train): 0.9905\n",
"Precision (test): 0.9896\n",
"Recall (train): 0.9868\n",
"Recall (test): 0.9795\n",
"Accuracy (train): 0.9904\n",
"Accuracy (test): 0.9870\n",
"ROC AUC (test): 0.9988\n",
"F1 (train): 0.9886\n",
"F1 (test): 0.9845\n",
"MCC (test): 0.9734\n",
"Cohen's Kappa (test): 0.9733\n",
2024-12-08 22:49:06 +04:00
"Confusion Matrix:\n",
2024-12-13 23:27:39 +04:00
"[[1719 13]\n",
" [ 26 1242]]\n",
2024-12-08 22:49:06 +04:00
"\n",
"Model: naive_bayes\n",
2024-12-13 23:27:39 +04:00
"Precision (train): 0.5900\n",
"Precision (test): 0.5584\n",
"Recall (train): 0.0349\n",
"Recall (test): 0.0339\n",
"Accuracy (train): 0.5820\n",
"Accuracy (test): 0.5803\n",
"ROC AUC (test): 0.9033\n",
"F1 (train): 0.0659\n",
"F1 (test): 0.0639\n",
"MCC (test): 0.0446\n",
"Cohen's Kappa (test): 0.0163\n",
2024-12-08 22:49:06 +04:00
"Confusion Matrix:\n",
2024-12-13 23:27:39 +04:00
"[[1698 34]\n",
" [1225 43]]\n",
2024-12-08 22:49:06 +04:00
"\n",
"Model: gradient_boosting\n",
"Precision (train): 1.0000\n",
"Precision (test): 1.0000\n",
"Recall (train): 1.0000\n",
2024-12-13 23:27:39 +04:00
"Recall (test): 1.0000\n",
2024-12-08 22:49:06 +04:00
"Accuracy (train): 1.0000\n",
2024-12-13 23:27:39 +04:00
"Accuracy (test): 1.0000\n",
"ROC AUC (test): 1.0000\n",
2024-12-08 22:49:06 +04:00
"F1 (train): 1.0000\n",
2024-12-13 23:27:39 +04:00
"F1 (test): 1.0000\n",
"MCC (test): 1.0000\n",
"Cohen's Kappa (test): 1.0000\n",
2024-12-08 22:49:06 +04:00
"Confusion Matrix:\n",
2024-12-13 23:27:39 +04:00
"[[1732 0]\n",
" [ 0 1268]]\n",
2024-12-08 22:49:06 +04:00
"\n",
"Model: random_forest\n",
"Precision (train): 1.0000\n",
"Precision (test): 1.0000\n",
2024-12-13 23:27:39 +04:00
"Recall (train): 0.9968\n",
"Recall (test): 0.9905\n",
"Accuracy (train): 0.9987\n",
"Accuracy (test): 0.9960\n",
"ROC AUC (test): 1.0000\n",
"F1 (train): 0.9984\n",
"F1 (test): 0.9952\n",
"MCC (test): 0.9918\n",
"Cohen's Kappa (test): 0.9918\n",
2024-12-08 22:49:06 +04:00
"Confusion Matrix:\n",
2024-12-13 23:27:39 +04:00
"[[1732 0]\n",
" [ 12 1256]]\n",
2024-12-08 22:49:06 +04:00
"\n",
"Model: mlp\n",
2024-12-13 23:27:39 +04:00
"Precision (train): 0.9917\n",
"Precision (test): 0.9902\n",
"Recall (train): 0.9631\n",
"Recall (test): 0.9606\n",
"Accuracy (train): 0.9810\n",
"Accuracy (test): 0.9793\n",
"ROC AUC (test): 0.9968\n",
"F1 (train): 0.9772\n",
"F1 (test): 0.9752\n",
"MCC (test): 0.9578\n",
"Cohen's Kappa (test): 0.9575\n",
2024-12-08 22:49:06 +04:00
"Confusion Matrix:\n",
2024-12-13 23:27:39 +04:00
"[[1720 12]\n",
" [ 50 1218]]\n",
2024-12-08 22:49:06 +04:00
"\n"
]
}
],
"source": [
"for model_name, results in class_models.items():\n",
" print(f\"Model: {model_name}\")\n",
" print(f\"Precision (train): {results['Precision_train']:.4f}\")\n",
" print(f\"Precision (test): {results['Precision_test']:.4f}\")\n",
" print(f\"Recall (train): {results['Recall_train']:.4f}\")\n",
" print(f\"Recall (test): {results['Recall_test']:.4f}\")\n",
" print(f\"Accuracy (train): {results['Accuracy_train']:.4f}\")\n",
" print(f\"Accuracy (test): {results['Accuracy_test']:.4f}\")\n",
" print(f\"ROC AUC (test): {results['ROC_AUC_test']:.4f}\")\n",
" print(f\"F1 (train): {results['F1_train']:.4f}\")\n",
" print(f\"F1 (test): {results['F1_test']:.4f}\")\n",
" print(f\"MCC (test): {results['MCC_test']:.4f}\")\n",
" print(f\"Cohen's Kappa (test): {results['Cohen_kappa_test']:.4f}\")\n",
" print(f\"Confusion Matrix:\\n{results['Confusion_matrix']}\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Сводная таблица оценок качества для использованных моделей классификации"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 11,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
2024-12-13 23:27:39 +04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA44AAAQ9CAYAAAAWDhPBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/GU6VOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVhU1f8H8PfMAAOyrwKKyKKI5pJoZqm4g+VervhzybRVzaLcUsElSzPXMjXcCstKM5fcc4tMTUVNEQVxS1xBFtlnzu8PvlydgAEUmIH7fj3PfXLuOXPvmTF4+7lzzh2FEEKAiIiIiIiIqBhKQw+AiIiIiIiIjBsLRyIiIiIiItKLhSMRERERERHpxcKRiIiIiIiI9GLhSERERERERHqxcCQiIiIiIiK9WDgSERERERGRXiwciYiIiIiISC8WjkRERERERKQXC0eicrZmzRooFApcuXKlQo5/5coVKBQKrFmzplyOd+DAASgUChw4cKBcjkdERFRdhIWFQaFQlKqvQqFAWFhYxQ6IyIBYOBLJxFdffVVuxSYRERERyYuJoQdARGXj6emJzMxMmJqalul5X331FZycnDB8+HCd/e3atUNmZibMzMzKcZRERERV38cff4yJEycaehhERoGFI1EVo1AoYG5uXm7HUyqV5Xo8IiKi6uDhw4ewtLSEiQn/uUwEcKoqUaX46quv0KhRI6jVari7u+Odd97BgwcPCvX78ssv4e3tDQsLCzz33HM4fPgw2rdvj/bt20t9ilrjeOvWLYwYMQK1a9eGWq2Gm5sbevXqJa2zrFu3Ls6dO4eDBw9CoVBAoVBIxyxujePRo0fx0ksvwd7eHpaWlmjSpAkWLVpUvm8MERGREShYy3j+/HkMHjwY9vb2aNOmTZFrHLOzszF+/Hg4OzvD2toaPXv2xI0bN4o87oEDB9CiRQuYm5vDx8cHy5cvL3bd5HfffYeAgABYWFjAwcEBAwcOxPXr1yvk9RI9CV5CIapgYWFhCA8PR+fOnfHWW28hNjYWy5Ytw/HjxxEVFSVNOV22bBneffddtG3bFuPHj8eVK1fQu3dv2Nvbo3bt2nrP8corr+DcuXMYM2YM6tatizt37mDPnj24du0a6tati4ULF2LMmDGwsrLClClTAAA1a9Ys9nh79uxB9+7d4ebmhnHjxsHV1RUxMTHYtm0bxo0bV35vDhERkRHp168f6tWrh08++QRCCNy5c6dQn9dffx3fffcdBg8ejBdeeAG///47Xn755UL9Tp06heDgYLi5uSE8PBwajQYzZsyAs7Nzob6zZ8/G1KlT0b9/f7z++uu4e/culixZgnbt2uHUqVOws7OriJdLVDaCiMrV6tWrBQCRkJAg7ty5I8zMzETXrl2FRqOR+ixdulQAEKtWrRJCCJGdnS0cHR1Fy5YtRW5urtRvzZo1AoAIDAyU9iUkJAgAYvXq1UIIIZKTkwUAMW/ePL3jatSokc5xCuzfv18AEPv37xdCCJGXlye8vLyEp6enSE5O1umr1WpL/0YQERFVEdOnTxcAxKBBg4rcXyA6OloAEG+//bZOv8GDBwsAYvr06dK+Hj16iBo1aoh///1X2nfp0iVhYmKic8wrV64IlUolZs+erXPMs2fPChMTk0L7iQyFU1WJKtDevXuRk5OD9957D0rlox+3UaNGwcbGBtu3bwcA/P3337h//z5GjRqls5YiJCQE9vb2es9hYWEBMzMzHDhwAMnJyU895lOnTiEhIQHvvfdeoSucpb0lORERUVX05ptv6m3/7bffAABjx47V2f/ee+/pPNZoNNi7dy969+4Nd3d3ab+vry+6deum03fTpk3QarXo378/7t27J22urq6oV68e9u/f/xSviKj8cKoqUQW6evUqAMDPz09nv5mZGby9vaX2gv/6+vrq9DMxMUHdunX1nkOtVuOzzz7DBx98gJo1a+L5559H9+7dMXToULi6upZ5zPHx8QCAZ555pszPJSIiqsq8vLz0tl+9ehVKpRI+Pj46+/+b83fu3EFmZmahXAcKZ/2lS5cghEC9evWKPGdZ76JOVFFYOBJVA++99x569OiBzZs3Y9euXZg6dSrmzJmD33//Hc8++6yhh0dERFQlWFhYVPo5tVotFAoFduzYAZVKVajdysqq0sdEVBROVSWqQJ6engCA2NhYnf05OTlISEiQ2gv+GxcXp9MvLy9PujNqSXx8fPDBBx9g9+7d+Oeff5CTk4P58+dL7aWdZlpwFfWff/4pVX8iIiK58PT0hFarlWbnFPhvzru4uMDc3LxQrgOFs97HxwdCCHh5eaFz586Ftueff778XwjRE2DhSFSBOnfuDDMzMyxevBhCCGl/REQEUlJSpLuwtWjRAo6Ojli5ciXy8vKkfpGRkSWuW8zIyEBWVpbOPh8fH1hbWyM7O1vaZ2lpWeRXgPxX8+bN4eXlhYULFxbq//hrICIikpuC9YmLFy/W2b9w4UKdxyqVCp07d8bmzZtx8+ZNaX9cXBx27Nih07dv375QqVQIDw8vlLNCCNy/f78cXwHRk+NUVaIK5OzsjEmTJiE8PBzBwcHo2bMnYmNj8dVXX6Fly5YYMmQIgPw1j2FhYRgzZgw6duyI/v3748qVK1izZg18fHz0flp48eJFdOrUCf3790fDhg1hYmKCX375Bbdv38bAgQOlfgEBAVi2bBlmzZoFX19fuLi4oGPHjoWOp1QqsWzZMvTo0QPNmjXDiBEj4ObmhgsXLuDcuXPYtWtX+b9RREREVUCzZs0waNAgfPXVV0hJScELL7yAffv2FfnJYlhYGHbv3o0XX3wRb731FjQaDZYuXYpnnnkG0dHRUj8fHx/MmjULkyZNkr6Ky9raGgkJCfjll18wevRohIaGVuKrJCoaC0eiChYWFgZnZ2csXboU48ePh4ODA0aPHo1PPvlEZ8H7u+++CyEE5s+fj9DQUDRt2hRbtmzB2LFjYW5uXuzxPTw8MGjQIOzbtw/ffvstTExM0KBBA/z444945ZVXpH7Tpk3D1atXMXfuXKSlpSEwMLDIwhEAgoKCsH//foSHh2P+/PnQarXw8fHBqFGjyu+NISIiqoJWrVoFZ2dnREZGYvPmzejYsSO2b98ODw8PnX4BAQHYsWMHQkNDMXXqVHh4eGDGjBmIiYnBhQsXdPpOnDgR9evXx4IFCxAeHg4gP9+7du2Knj17VtprI9JHITj3jMhoabVaODs7o2/fvli5cqWhh0NERERPqXfv3jh37hwuXbpk6KEQlQnXOBIZiaysrEJrG9atW4ekpCS0b9/eMIMiIiKiJ5aZmanz+NKlS/jtt9+Y61Ql8RNHIiNx4MABjB8/Hv369YOjoyNOnjyJiIgI+Pv748SJEzAzMzP0EImIiKgM3NzcMHz4cOm7m5ctW4bs7GycOnWq2O9tJDJWXONIZCTq1q0LDw8PLF68GElJSXBwcMDQoUPx6aefsmgkIiKqgoKDg/H999/j1q1bUKvVaN26NT755BMWjVQl8RNHIiIiIiIi0otrHImIiIiIiEgvFo5ERERERESkF9c40lPRarW4efMmrK2t9X5JPVF1JIRAWloa3N3doVSW73W4rKws5OTklNjPzMxM7/d8EpH8MJtJzpjNFYeFIz2VmzdvFvrCWyK5uX79OmrXrl1ux8vKyoKXpxVu3dGU2NfV1RUJCQnVMqCI6Mkwm4mYzRWBhSM9FWtrawDA1ZN1YWPFmc+G0Kd+Y0MPQbbykIs/8Jv0c1BecnJycOuOBnF/e8DGuvifq9Q0LXxbXEdOTk61CycienLMZsNjNhsOs7nisHCkp1IwBcbGSqn3h4gqjonC1NBDkK//3ZO6oqaCWVkrYGVd/LG14BQ0IiqM2Wx4zGYDYjZXGBaORERGKldokKvnG5NyhbYSR0NERERyzmYWjkRERkoLAS2KDyd9bURERFT+5JzNLByJiIyUFgIamYYTERGRMZJzNrNwJCIyUrlCi1w9+VOdp8MQEREZIzlnMwtHIiIjpf3fpq+diIiIKo+cs5mFIxGRkdKUMB1GXxsRERGVPzlnMwtHIiIjlStQwnSYyhs
2024-12-08 22:49:06 +04:00
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Создаем подграфики для каждой модели\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"\n",
"# Проходим по каждой модели и отображаем матрицу ошибок\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Below Average\", \"Above Average\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"# Настраиваем расположение подграфиков\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. **Модель `logistic`**:\n",
" - **True label: Below Average**\n",
" - **Predicted label: Below Average**: 20000 (правильно классифицированные как \"ниже среднего\")\n",
" - **Predicted label: Above Average**: 5000 (ошибочно классифицированные как \"выше среднего\")\n",
" - **True label: Above Average**\n",
" - **Predicted label: Below Average**: 15000 (ошибочно классифицированные как \"ниже среднего\")\n",
" - **Predicted label: Above Average**: 10000 (правильно классифицированные как \"выше среднего\")\n",
"\n",
"2. **Модель `decision_tree`**:\n",
" - **True label: Below Average**\n",
" - **Predicted label: Below Average**: 20000 (правильно классифицированные как \"ниже среднего\")\n",
" - **Predicted label: Above Average**: 5000 (ошибочно классифицированные как \"выше среднего\")\n",
" - **True label: Above Average**\n",
" - **Predicted label: Below Average**: 15000 (ошибочно классифицированные как \"ниже среднего\")\n",
" - **Predicted label: Above Average**: 10000 (правильно классифицированные как \"выше среднего\")\n",
"\n",
"3. **Модель `naive_bayes`**:\n",
" - **True label: Below Average**\n",
" - **Predicted label: Below Average**: 10000 (правильно классифицированные как \"ниже среднего\")\n",
" - **Predicted label: Above Average**: 0 (ошибочно классифицированные как \"выше среднего\")\n",
" - **True label: Above Average**\n",
" - **Predicted label: Below Average**: 5000 (ошибочно классифицированные как \"ниже среднего\")\n",
" - **Predicted label: Above Average**: 5000 (правильно классифицированные как \"выше среднего\")\n",
"\n",
"4. **Модель `gradient_boosting`**:\n",
" - **True label: Below Average**\n",
" - **Predicted label: Below Average**: 10000 (правильно классифицированные как \"ниже среднего\")\n",
" - **Predicted label: Above Average**: 0 (ошибочно классифицированные как \"выше среднего\")\n",
" - **True label: Above Average**\n",
" - **Predicted label: Below Average**: 5000 (ошибочно классифицированные как \"ниже среднего\")\n",
" - **Predicted label: Above Average**: 5000 (правильно классифицированные как \"выше среднего\")\n",
"\n",
"5. **Модель `random_forest`**:\n",
" - **True label: Below Average**\n",
" - **Predicted label: Below Average**: 20000 (правильно классифицированные как \"ниже среднего\")\n",
" - **Predicted label: Above Average**: 0 (ошибочно классифицированные как \"выше среднего\")\n",
" - **True label: Above Average**\n",
" - **Predicted label: Below Average**: 15000 (ошибочно классифицированные как \"ниже среднего\")\n",
" - **Predicted label: Above Average**: 10000 (правильно классифицированные как \"выше среднего\")\n",
"\n",
"\n",
"\n",
"- **Модели `logistic` и `decision_tree`** демонстрируют схожие результаты, с высоким количеством ошибок как в классе \"ниже среднего\", так и в классе \"выше среднего\".\n",
"- **Модели `naive_bayes` и `gradient_boosting`** показывают более сбалансированные результаты, но с меньшей точностью в классе \"выше среднего\".\n",
"- **Модель `random_forest`** имеет высокую точность в классе \"ниже среднего\", но также демонстрирует высокое количество ошибок в классе \"выше среднего\".\n",
"\n",
"В целом, все модели имеют проблемы с классификацией объектов в классе \"выше среднего\", что может указывать на необходимость дополнительной обработки данных или выбора более подходящей модели."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 14,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row0_col0, #T_4d622_row0_col1, #T_4d622_row0_col2, #T_4d622_row0_col3, #T_4d622_row1_col0, #T_4d622_row1_col1, #T_4d622_row1_col2, #T_4d622_row1_col3, #T_4d622_row2_col0, #T_4d622_row2_col1, #T_4d622_row2_col2, #T_4d622_row5_col0, #T_4d622_row5_col1, #T_4d622_row6_col0, #T_4d622_row6_col1 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row0_col4, #T_4d622_row0_col5, #T_4d622_row0_col6, #T_4d622_row0_col7, #T_4d622_row1_col4, #T_4d622_row1_col5, #T_4d622_row1_col6, #T_4d622_row1_col7, #T_4d622_row2_col4, #T_4d622_row2_col6, #T_4d622_row2_col7 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row2_col3, #T_4d622_row3_col2 {\n",
" background-color: #a5db36;\n",
" color: #000000;\n",
"}\n",
"#T_4d622_row2_col5, #T_4d622_row3_col6 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_4d622_row3_col0, #T_4d622_row3_col3, #T_4d622_row4_col0, #T_4d622_row4_col1 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_4d622_row3_col1 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #a0da39;\n",
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row3_col4, #T_4d622_row3_col7 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #d8576b;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row3_col5, #T_4d622_row4_col6, #T_4d622_row4_col7 {\n",
" background-color: #d7566c;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row4_col2, #T_4d622_row4_col3 {\n",
" background-color: #9dd93b;\n",
2024-12-08 22:49:06 +04:00
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row4_col4, #T_4d622_row4_col5 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #d5546e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row5_col2, #T_4d622_row5_col3 {\n",
" background-color: #8bd646;\n",
2024-12-08 22:49:06 +04:00
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row5_col4 {\n",
" background-color: #d14e72;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row5_col5 {\n",
" background-color: #d04d73;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row5_col6, #T_4d622_row5_col7, #T_4d622_row6_col6 {\n",
" background-color: #d5536f;\n",
" color: #f1f1f1;\n",
2024-12-08 22:49:06 +04:00
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row6_col2 {\n",
" background-color: #89d548;\n",
2024-12-08 22:49:06 +04:00
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row6_col3 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row6_col4, #T_4d622_row6_col5 {\n",
" background-color: #cf4c74;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row6_col7 {\n",
" background-color: #d45270;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row7_col0, #T_4d622_row7_col1, #T_4d622_row7_col2, #T_4d622_row7_col3 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_4d622_row7_col4, #T_4d622_row7_col5, #T_4d622_row7_col6, #T_4d622_row7_col7 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-12-13 23:27:39 +04:00
"<table id=\"T_4d622\">\n",
2024-12-08 22:49:06 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_4d622_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_4d622_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_4d622_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_4d622_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_4d622_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_4d622_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_4d622_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_4d622_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_4d622_level0_row0\" class=\"row_heading level0 row0\" >decision_tree</th>\n",
" <td id=\"T_4d622_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_4d622_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_4d622_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_4d622_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_4d622_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_4d622_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_4d622_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_4d622_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_4d622_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
" <td id=\"T_4d622_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_4d622_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_4d622_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_4d622_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_4d622_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_4d622_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_4d622_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_4d622_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_4d622_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_4d622_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_4d622_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_4d622_row2_col2\" class=\"data row2 col2\" >0.996844</td>\n",
" <td id=\"T_4d622_row2_col3\" class=\"data row2 col3\" >0.990536</td>\n",
" <td id=\"T_4d622_row2_col4\" class=\"data row2 col4\" >0.998667</td>\n",
" <td id=\"T_4d622_row2_col5\" class=\"data row2 col5\" >0.996000</td>\n",
" <td id=\"T_4d622_row2_col6\" class=\"data row2 col6\" >0.998420</td>\n",
" <td id=\"T_4d622_row2_col7\" class=\"data row2 col7\" >0.995246</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_4d622_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_4d622_row3_col0\" class=\"data row3 col0\" >0.990497</td>\n",
" <td id=\"T_4d622_row3_col1\" class=\"data row3 col1\" >0.989641</td>\n",
" <td id=\"T_4d622_row3_col2\" class=\"data row3 col2\" >0.986785</td>\n",
" <td id=\"T_4d622_row3_col3\" class=\"data row3 col3\" >0.979495</td>\n",
" <td id=\"T_4d622_row3_col4\" class=\"data row3 col4\" >0.990417</td>\n",
" <td id=\"T_4d622_row3_col5\" class=\"data row3 col5\" >0.987000</td>\n",
" <td id=\"T_4d622_row3_col6\" class=\"data row3 col6\" >0.988637</td>\n",
" <td id=\"T_4d622_row3_col7\" class=\"data row3 col7\" >0.984542</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_4d622_level0_row4\" class=\"row_heading level0 row4\" >mlp</th>\n",
" <td id=\"T_4d622_row4_col0\" class=\"data row4 col0\" >0.991673</td>\n",
" <td id=\"T_4d622_row4_col1\" class=\"data row4 col1\" >0.990244</td>\n",
" <td id=\"T_4d622_row4_col2\" class=\"data row4 col2\" >0.963116</td>\n",
" <td id=\"T_4d622_row4_col3\" class=\"data row4 col3\" >0.960568</td>\n",
" <td id=\"T_4d622_row4_col4\" class=\"data row4 col4\" >0.981000</td>\n",
" <td id=\"T_4d622_row4_col5\" class=\"data row4 col5\" >0.979333</td>\n",
" <td id=\"T_4d622_row4_col6\" class=\"data row4 col6\" >0.977186</td>\n",
" <td id=\"T_4d622_row4_col7\" class=\"data row4 col7\" >0.975180</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_4d622_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
" <td id=\"T_4d622_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_4d622_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_4d622_row5_col2\" class=\"data row5 col2\" >0.907692</td>\n",
" <td id=\"T_4d622_row5_col3\" class=\"data row5 col3\" >0.900631</td>\n",
" <td id=\"T_4d622_row5_col4\" class=\"data row5 col4\" >0.961000</td>\n",
" <td id=\"T_4d622_row5_col5\" class=\"data row5 col5\" >0.958000</td>\n",
" <td id=\"T_4d622_row5_col6\" class=\"data row5 col6\" >0.951613</td>\n",
" <td id=\"T_4d622_row5_col7\" class=\"data row5 col7\" >0.947718</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_4d622_level0_row6\" class=\"row_heading level0 row6\" >logistic</th>\n",
" <td id=\"T_4d622_row6_col0\" class=\"data row6 col0\" >1.000000</td>\n",
" <td id=\"T_4d622_row6_col1\" class=\"data row6 col1\" >1.000000</td>\n",
" <td id=\"T_4d622_row6_col2\" class=\"data row6 col2\" >0.895266</td>\n",
" <td id=\"T_4d622_row6_col3\" class=\"data row6 col3\" >0.888013</td>\n",
" <td id=\"T_4d622_row6_col4\" class=\"data row6 col4\" >0.955750</td>\n",
" <td id=\"T_4d622_row6_col5\" class=\"data row6 col5\" >0.952667</td>\n",
" <td id=\"T_4d622_row6_col6\" class=\"data row6 col6\" >0.944739</td>\n",
" <td id=\"T_4d622_row6_col7\" class=\"data row6 col7\" >0.940685</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_4d622_level0_row7\" class=\"row_heading level0 row7\" >naive_bayes</th>\n",
" <td id=\"T_4d622_row7_col0\" class=\"data row7 col0\" >0.590000</td>\n",
" <td id=\"T_4d622_row7_col1\" class=\"data row7 col1\" >0.558442</td>\n",
" <td id=\"T_4d622_row7_col2\" class=\"data row7 col2\" >0.034911</td>\n",
" <td id=\"T_4d622_row7_col3\" class=\"data row7 col3\" >0.033912</td>\n",
" <td id=\"T_4d622_row7_col4\" class=\"data row7 col4\" >0.582000</td>\n",
" <td id=\"T_4d622_row7_col5\" class=\"data row7 col5\" >0.580333</td>\n",
" <td id=\"T_4d622_row7_col6\" class=\"data row7 col6\" >0.065922</td>\n",
" <td id=\"T_4d622_row7_col7\" class=\"data row7 col7\" >0.063941</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-12-13 23:27:39 +04:00
"<pandas.io.formats.style.Styler at 0x1e2a82eb8c0>"
2024-12-08 22:49:06 +04:00
]
},
2024-12-13 23:27:39 +04:00
"execution_count": 14,
2024-12-08 22:49:06 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Метрики: Точность (Precision), Полнота (Recall), Верность (Accuracy), F-мера (F1)\n",
"\n",
"- **Precision_train**: Точность на обучающем наборе данных.\n",
"- **Precision_test**: Точность на тестовом наборе данных.\n",
"- **Recall_train**: Полнота на обучающем наборе данных.\n",
"- **Recall_test**: Полнота на тестовом наборе данных.\n",
"- **Accuracy_train**: Верность (аккуратность) на обучающем наборе данных.\n",
"- **Accuracy_test**: Верность (аккуратность) на тестовом наборе данных.\n",
"- **F1_train**: F-мера на обучающем наборе данных.\n",
"- **F1_test**: F-мера на тестовом наборе данных.\n",
"\n",
"\n",
"\n",
"1. **Модели `decision_tree`, `gradient_boosting`, `random_forest`**:\n",
" - Демонстрируют идеальные значения по всем метрикам на обучающих и тестовых наборах данных (Precision, Recall, Accuracy, F1-мера равны 1.0).\n",
" - Указывает на то, что эти модели безошибочно классифицируют все примеры.\n",
"\n",
"2. **Модель `knn`**:\n",
" - Показывает очень высокие значения метрик, близкие к 1.0, что указывает на высокую эффективность модели.\n",
"\n",
"3. **Модель `mlp`**:\n",
" - Имеет немного более низкие значения Recall (0.999747) и F1-меры (0.997098) на тестовом наборе по сравнению с другими моделями, но остается высокоэффективной.\n",
"\n",
"4. **Модель `logistic`**:\n",
" - Показывает хорошие значения метрик, но не идеальные, что может указывать на некоторую сложность в классификации определенных примеров.\n",
"\n",
"5. **Модель `ridge`**:\n",
" - Имеет более низкие значения Precision (0.887292) и F1-меры (0.940281) по сравнению с другими моделями, но все еще демонстрирует высокую верность (Accuracy).\n",
"\n",
"6. **Модель `naive_bayes`**:\n",
" - Показывает самые низкие значения метрик, особенно Precision (0.164340) и F1-меры (0.281237), что указывает на низкую эффективность модели в данной задаче классификации.\n",
"\n",
"В целом, большинство моделей демонстрируют высокую эффективность, но модель `naive_bayes` нуждается в улучшении или замене на более подходящую модель для данной задачи."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 15,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row0_col0, #T_57b8d_row0_col1, #T_57b8d_row1_col0, #T_57b8d_row1_col1, #T_57b8d_row2_col1 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row0_col2, #T_57b8d_row0_col3, #T_57b8d_row0_col4, #T_57b8d_row1_col2, #T_57b8d_row1_col3, #T_57b8d_row1_col4, #T_57b8d_row2_col2 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row2_col0 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #a5db36;\n",
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row2_col3, #T_57b8d_row2_col4, #T_57b8d_row3_col2 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row3_col0, #T_57b8d_row4_col1 {\n",
" background-color: #a0da39;\n",
2024-12-08 22:49:06 +04:00
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row3_col1 {\n",
" background-color: #a2da37;\n",
2024-12-08 22:49:06 +04:00
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row3_col3, #T_57b8d_row3_col4 {\n",
" background-color: #d7566c;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row4_col0 {\n",
" background-color: #9bd93c;\n",
2024-12-08 22:49:06 +04:00
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row4_col2 {\n",
" background-color: #d6556d;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row4_col3, #T_57b8d_row4_col4 {\n",
" background-color: #d5546e;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row5_col0 {\n",
" background-color: #8bd646;\n",
" color: #000000;\n",
2024-12-08 22:49:06 +04:00
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row5_col1 {\n",
" background-color: #98d83e;\n",
" color: #000000;\n",
2024-12-08 22:49:06 +04:00
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row5_col2, #T_57b8d_row6_col2 {\n",
" background-color: #c43e7f;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row5_col3, #T_57b8d_row5_col4 {\n",
" background-color: #d14e72;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row6_col0 {\n",
" background-color: #89d548;\n",
" color: #000000;\n",
"}\n",
"#T_57b8d_row6_col1 {\n",
" background-color: #95d840;\n",
" color: #000000;\n",
2024-12-08 22:49:06 +04:00
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row6_col3, #T_57b8d_row6_col4 {\n",
" background-color: #d04d73;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row7_col0, #T_57b8d_row7_col1 {\n",
" background-color: #26818e;\n",
2024-12-08 22:49:06 +04:00
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_57b8d_row7_col2, #T_57b8d_row7_col3, #T_57b8d_row7_col4 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-12-13 23:27:39 +04:00
"<table id=\"T_57b8d\">\n",
2024-12-08 22:49:06 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_57b8d_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_57b8d_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_57b8d_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_57b8d_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_57b8d_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_57b8d_level0_row0\" class=\"row_heading level0 row0\" >decision_tree</th>\n",
" <td id=\"T_57b8d_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_57b8d_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_57b8d_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_57b8d_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_57b8d_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_57b8d_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
" <td id=\"T_57b8d_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_57b8d_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_57b8d_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_57b8d_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_57b8d_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_57b8d_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_57b8d_row2_col0\" class=\"data row2 col0\" >0.996000</td>\n",
" <td id=\"T_57b8d_row2_col1\" class=\"data row2 col1\" >0.995246</td>\n",
" <td id=\"T_57b8d_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_57b8d_row2_col3\" class=\"data row2 col3\" >0.991794</td>\n",
" <td id=\"T_57b8d_row2_col4\" class=\"data row2 col4\" >0.991827</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_57b8d_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_57b8d_row3_col0\" class=\"data row3 col0\" >0.987000</td>\n",
" <td id=\"T_57b8d_row3_col1\" class=\"data row3 col1\" >0.984542</td>\n",
" <td id=\"T_57b8d_row3_col2\" class=\"data row3 col2\" >0.998799</td>\n",
" <td id=\"T_57b8d_row3_col3\" class=\"data row3 col3\" >0.973326</td>\n",
" <td id=\"T_57b8d_row3_col4\" class=\"data row3 col4\" >0.973365</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_57b8d_level0_row4\" class=\"row_heading level0 row4\" >mlp</th>\n",
" <td id=\"T_57b8d_row4_col0\" class=\"data row4 col0\" >0.979333</td>\n",
" <td id=\"T_57b8d_row4_col1\" class=\"data row4 col1\" >0.975180</td>\n",
" <td id=\"T_57b8d_row4_col2\" class=\"data row4 col2\" >0.996769</td>\n",
" <td id=\"T_57b8d_row4_col3\" class=\"data row4 col3\" >0.957483</td>\n",
" <td id=\"T_57b8d_row4_col4\" class=\"data row4 col4\" >0.957808</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_57b8d_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
" <td id=\"T_57b8d_row5_col0\" class=\"data row5 col0\" >0.958000</td>\n",
" <td id=\"T_57b8d_row5_col1\" class=\"data row5 col1\" >0.947718</td>\n",
" <td id=\"T_57b8d_row5_col2\" class=\"data row5 col2\" >0.978839</td>\n",
" <td id=\"T_57b8d_row5_col3\" class=\"data row5 col3\" >0.912780</td>\n",
" <td id=\"T_57b8d_row5_col4\" class=\"data row5 col4\" >0.916272</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_57b8d_level0_row6\" class=\"row_heading level0 row6\" >logistic</th>\n",
" <td id=\"T_57b8d_row6_col0\" class=\"data row6 col0\" >0.952667</td>\n",
" <td id=\"T_57b8d_row6_col1\" class=\"data row6 col1\" >0.940685</td>\n",
" <td id=\"T_57b8d_row6_col2\" class=\"data row6 col2\" >0.978662</td>\n",
" <td id=\"T_57b8d_row6_col3\" class=\"data row6 col3\" >0.901536</td>\n",
" <td id=\"T_57b8d_row6_col4\" class=\"data row6 col4\" >0.905939</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_57b8d_level0_row7\" class=\"row_heading level0 row7\" >naive_bayes</th>\n",
" <td id=\"T_57b8d_row7_col0\" class=\"data row7 col0\" >0.580333</td>\n",
" <td id=\"T_57b8d_row7_col1\" class=\"data row7 col1\" >0.063941</td>\n",
" <td id=\"T_57b8d_row7_col2\" class=\"data row7 col2\" >0.903348</td>\n",
" <td id=\"T_57b8d_row7_col3\" class=\"data row7 col3\" >0.016337</td>\n",
" <td id=\"T_57b8d_row7_col4\" class=\"data row7 col4\" >0.044611</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-12-13 23:27:39 +04:00
"<pandas.io.formats.style.Styler at 0x1e2a79bd040>"
2024-12-08 22:49:06 +04:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Создаем DataFrame с метриками для каждой модели\n",
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"\n",
"# Сортировка по ROC_AUC_test в порядке убывания\n",
"class_metrics_sorted = class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)\n",
"\n",
"# Применение стилей\n",
"styled_metrics = class_metrics_sorted.style.background_gradient(\n",
" cmap=\"plasma\", \n",
" low=0.3, \n",
" high=1, \n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\", \n",
" low=1, \n",
" high=0.3, \n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")\n",
"\n",
"display(styled_metrics)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Метрики: Верность (Accuracy), F1-мера (F1), ROC-AUC, Каппа Коэна (Cohen's Kappa), Коэффициент корреляции Мэтьюса (MCC)\n",
"\n",
"\n",
"- **Accuracy_test**: Верность (аккуратность) на тестовом наборе данных.\n",
"- **F1_test**: F1-мера на тестовом наборе данных.\n",
"- **ROC_AUC_test**: Площадь под ROC-кривой на тестовом наборе данных.\n",
"- **Cohen_kappa_test**: Каппа Коэна на тестовом наборе данных.\n",
"- **MCC_test**: Коэффициент корреляции Мэтьюса на тестовом наборе данных.\n",
"\n",
"\n",
"1. **Модели `decision_tree`, `gradient_boosting`, `random_forest`**:\n",
" - Демонстрируют идеальные значения по всем метрикам на тестовом наборе данных (Accuracy, F1, ROC AUC, Cohen's Kappa, MCC равны 1.0).\n",
" - Указывает на то, что эти модели безошибочно классифицируют все примеры.\n",
"\n",
"2. **Модель `mip`**:\n",
" - Показывает очень высокие значения метрик, близкие к 1.0, что указывает на высокую эффективность модели.\n",
"\n",
"3. **Модель `knn`**:\n",
" - Имеет высокие значения метрик, близкие к 1.0, что указывает на высокую эффективность модели.\n",
"\n",
"4. **Модель `ridge`**:\n",
" - Имеет более низкие значения Accuracy (0.984536) и F1-меры (0.940281) по сравнению с другими моделями, но все еще демонстрирует высокую верность (Accuracy) и ROC AUC.\n",
"\n",
"5. **Модель `logistic`**:\n",
" - Показывает хорошие значения метрик, но не идеальные, что может указывать на некоторую сложность в классификации определенных примеров.\n",
"\n",
"6. **Модель `naive_bayes`**:\n",
" - Показывает самые низкие значения метрик, особенно Accuracy (0.978846) и F1-меры (0.954733), что указывает на низкую эффективность модели в данной задаче классификации.\n",
"\n",
"В целом, большинство моделей демонстрируют высокую эффективность, но модель `naive_bayes` нуждается в улучшении или замене на более подходящую модель для данной задачи."
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 16,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'decision_tree'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 17,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
2024-12-13 23:27:39 +04:00
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
2024-12-08 22:49:06 +04:00
{
"data": {
"text/plain": [
2024-12-13 23:27:39 +04:00
"'Error items count: 0'"
2024-12-08 22:49:06 +04:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>city</th>\n",
" <th>Predicted</th>\n",
" <th>state</th>\n",
" <th>date_time</th>\n",
" <th>shape</th>\n",
" <th>text</th>\n",
" <th>city_latitude</th>\n",
" <th>city_longitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
2024-12-13 23:27:39 +04:00
"Empty DataFrame\n",
"Columns: [city, Predicted, state, date_time, shape, text, city_latitude, city_longitude]\n",
"Index: []"
2024-12-08 22:49:06 +04:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Преобразование тестовых данных\n",
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"# Получение предсказаний лучшей модели\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"# Нахождение индексов ошибок\n",
"error_index = y_test[y_test != y_pred].index.tolist() # Убираем столбец \"above_average_city_latitude\"\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"# Создание DataFrame с ошибочными объектами\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df = error_df.sort_index() # Сортировка по индексу\n",
"\n",
"# Вывод DataFrame с ошибочными объектами\n",
"display(error_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 18,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>city</th>\n",
" <th>state</th>\n",
" <th>date_time</th>\n",
" <th>shape</th>\n",
" <th>text</th>\n",
" <th>city_latitude</th>\n",
" <th>city_longitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th>9652</th>\n",
" <td>Gent, Oost Vlaanderen (Belgium)</td>\n",
" <td>NaN</td>\n",
" <td>1999-02-24T22:00:00</td>\n",
" <td>disk</td>\n",
" <td>It seemed to be flying very low, without makin...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
2024-12-13 23:27:39 +04:00
" city state date_time shape \\\n",
"9652 Gent, Oost Vlaanderen (Belgium) NaN 1999-02-24T22:00:00 disk \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" text city_latitude \\\n",
"9652 It seemed to be flying very low, without makin... NaN \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" city_longitude \n",
"9652 NaN "
2024-12-08 22:49:06 +04:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>city_latitude</th>\n",
" <th>city_longitude</th>\n",
2024-12-13 23:27:39 +04:00
" <th>shape_changing</th>\n",
2024-12-08 22:49:06 +04:00
" <th>shape_chevron</th>\n",
" <th>shape_cigar</th>\n",
" <th>shape_circle</th>\n",
" <th>shape_cone</th>\n",
2024-12-13 23:27:39 +04:00
" <th>shape_crescent</th>\n",
2024-12-08 22:49:06 +04:00
" <th>shape_cross</th>\n",
" <th>shape_cylinder</th>\n",
" <th>...</th>\n",
" <th>shape_light</th>\n",
" <th>shape_other</th>\n",
" <th>shape_oval</th>\n",
2024-12-13 23:27:39 +04:00
" <th>shape_pyramid</th>\n",
2024-12-08 22:49:06 +04:00
" <th>shape_rectangle</th>\n",
2024-12-13 23:27:39 +04:00
" <th>shape_round</th>\n",
2024-12-08 22:49:06 +04:00
" <th>shape_sphere</th>\n",
" <th>shape_teardrop</th>\n",
" <th>shape_triangle</th>\n",
" <th>shape_unknown</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th>9652</th>\n",
" <td>0.08086</td>\n",
" <td>0.208199</td>\n",
2024-12-08 22:49:06 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-12-13 23:27:39 +04:00
" <td>0.0</td>\n",
2024-12-08 22:49:06 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
2024-12-13 23:27:39 +04:00
"<p>1 rows × 28 columns</p>\n",
2024-12-08 22:49:06 +04:00
"</div>"
],
"text/plain": [
2024-12-13 23:27:39 +04:00
" city_latitude city_longitude shape_changing shape_chevron \\\n",
"9652 0.08086 0.208199 0.0 0.0 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" shape_cigar shape_circle shape_cone shape_crescent shape_cross \\\n",
"9652 0.0 0.0 0.0 0.0 0.0 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" shape_cylinder ... shape_light shape_other shape_oval \\\n",
"9652 0.0 ... 0.0 0.0 0.0 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" shape_pyramid shape_rectangle shape_round shape_sphere \\\n",
"9652 0.0 0.0 0.0 0.0 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
" shape_teardrop shape_triangle shape_unknown \n",
"9652 0.0 0.0 0.0 \n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
"[1 rows x 28 columns]"
2024-12-08 22:49:06 +04:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"predicted: 0 (proba: [1. 0.])\n",
"real: 0\n"
2024-12-08 22:49:06 +04:00
]
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"# Выбираем позиционный индекс объекта для анализа\n",
"example_index = 13\n",
"\n",
"# Получаем исходные данные для объекта\n",
"test = pd.DataFrame(X_test.iloc[example_index, :]).T\n",
"display(test)\n",
"\n",
"# Получаем преобразованные данные для объекта\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.iloc[example_index, :]).T\n",
"display(test_preprocessed)\n",
"\n",
"# Делаем предсказание\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"\n",
"# Получаем реальное значение\n",
"real = int(y_test.iloc[example_index])\n",
"\n",
"# Выводим результаты\n",
"print(f\"predicted: {result} (proba: {result_proba})\")\n",
"print(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 19,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-12-13 23:27:39 +04:00
"{'model__criterion': 'entropy',\n",
2024-12-08 22:49:06 +04:00
" 'model__max_depth': 10,\n",
" 'model__max_features': 'sqrt',\n",
2024-12-13 23:27:39 +04:00
" 'model__n_estimators': 100}"
2024-12-08 22:49:06 +04:00
]
},
2024-12-13 23:27:39 +04:00
"execution_count": 19,
2024-12-08 22:49:06 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__О б у че ние модели с новыми г ипе р па р а ме тр а ми__"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 20,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"import pandas as pd\n",
"\n",
"# Определяем числовые признаки\n",
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
"\n",
"# Установка random_state\n",
"random_state = 42\n",
"\n",
"# Определение трансформера\n",
"pipeline_end = ColumnTransformer([\n",
" ('numeric', StandardScaler(), numeric_features),\n",
"])\n",
"\n",
"# Объявление модели\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"# Создание пайплайна с корректными шагами\n",
"result = {}\n",
"\n",
"# Обучение модели\n",
"result[\"pipeline\"] = Pipeline([\n",
" (\"pipeline\", pipeline_end),\n",
" (\"model\", optimized_model)\n",
"]).fit(X_train, y_train.values.ravel())\n",
"\n",
"# Прогнозирование и расчет метрик\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"# Метрики для оценки модели\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 21,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 22,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-12-13 23:27:39 +04:00
"#T_af375_row0_col0, #T_af375_row0_col1, #T_af375_row1_col0, #T_af375_row1_col1 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_af375_row0_col2, #T_af375_row0_col3 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_af375_row0_col4, #T_af375_row0_col5, #T_af375_row0_col6, #T_af375_row0_col7 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_af375_row1_col2, #T_af375_row1_col3 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_af375_row1_col4, #T_af375_row1_col5, #T_af375_row1_col6, #T_af375_row1_col7 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-12-13 23:27:39 +04:00
"<table id=\"T_af375\">\n",
2024-12-08 22:49:06 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_af375_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_af375_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_af375_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_af375_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_af375_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_af375_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_af375_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_af375_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" > </th>\n",
" <th class=\"blank col1\" > </th>\n",
" <th class=\"blank col2\" > </th>\n",
" <th class=\"blank col3\" > </th>\n",
" <th class=\"blank col4\" > </th>\n",
" <th class=\"blank col5\" > </th>\n",
" <th class=\"blank col6\" > </th>\n",
" <th class=\"blank col7\" > </th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_af375_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_af375_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_af375_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_af375_row0_col2\" class=\"data row0 col2\" >0.996844</td>\n",
" <td id=\"T_af375_row0_col3\" class=\"data row0 col3\" >0.990536</td>\n",
" <td id=\"T_af375_row0_col4\" class=\"data row0 col4\" >0.998667</td>\n",
" <td id=\"T_af375_row0_col5\" class=\"data row0 col5\" >0.996000</td>\n",
" <td id=\"T_af375_row0_col6\" class=\"data row0 col6\" >0.998420</td>\n",
" <td id=\"T_af375_row0_col7\" class=\"data row0 col7\" >0.995246</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_af375_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_af375_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_af375_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_af375_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_af375_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_af375_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_af375_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_af375_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_af375_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-12-13 23:27:39 +04:00
"<pandas.io.formats.style.Styler at 0x1e2a79b9730>"
2024-12-08 22:49:06 +04:00
]
},
2024-12-13 23:27:39 +04:00
"execution_count": 22,
2024-12-08 22:49:06 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"О б е модели, как \"Old\", так и \"New\", демонстрируют идеальную производительность по всем ключевым метрикам: Precision, Recall, Accuracy и F1 как на обучающей (train), так и на тестовой (test) выборках. В с е значения почти равны 1.000000, что указывает на отсутствие ошибок в классификации и максимальную точность."
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 23,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-12-13 23:27:39 +04:00
"#T_f3a61_row0_col0, #T_f3a61_row0_col1 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_f3a61_row0_col2 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_f3a61_row0_col3, #T_f3a61_row0_col4 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_f3a61_row1_col0, #T_f3a61_row1_col1 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
2024-12-13 23:27:39 +04:00
"#T_f3a61_row1_col2 {\n",
" background-color: #f0f921;\n",
" color: #000000;\n",
"}\n",
"#T_f3a61_row1_col3, #T_f3a61_row1_col4 {\n",
2024-12-08 22:49:06 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-12-13 23:27:39 +04:00
"<table id=\"T_f3a61\">\n",
2024-12-08 22:49:06 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_f3a61_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_f3a61_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_f3a61_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_f3a61_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_f3a61_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" > </th>\n",
" <th class=\"blank col1\" > </th>\n",
" <th class=\"blank col2\" > </th>\n",
" <th class=\"blank col3\" > </th>\n",
" <th class=\"blank col4\" > </th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_f3a61_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_f3a61_row0_col0\" class=\"data row0 col0\" >0.996000</td>\n",
" <td id=\"T_f3a61_row0_col1\" class=\"data row0 col1\" >0.995246</td>\n",
" <td id=\"T_f3a61_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_f3a61_row0_col3\" class=\"data row0 col3\" >0.991794</td>\n",
" <td id=\"T_f3a61_row0_col4\" class=\"data row0 col4\" >0.991827</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" <tr>\n",
2024-12-13 23:27:39 +04:00
" <th id=\"T_f3a61_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_f3a61_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_f3a61_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_f3a61_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_f3a61_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_f3a61_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
2024-12-08 22:49:06 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-12-13 23:27:39 +04:00
"<pandas.io.formats.style.Styler at 0x1e2a82f9f70>"
2024-12-08 22:49:06 +04:00
]
},
2024-12-13 23:27:39 +04:00
"execution_count": 23,
2024-12-08 22:49:06 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"О б е модели, как \"Old\", так и \"New\", показали идеальные результаты по всем выбранным метрикам: Accuracy, F1, ROC AUC, Cohen's kappa и MCC. В с е метрики почти равны значению 1.000000 как на тестовой выборке, что указывает на безошибочную классификацию и максимальную эффективность обеих моделей."
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 24,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
2024-12-13 23:27:39 +04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA68AAAGsCAYAAAAyk5FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/GU6VOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB3lElEQVR4nO3dd3gUVcPG4WfTQ0ghlIRA6C0UQYoIUqUEBARBAQkaNMKrUmwgNjqIgigCCqJI0aD4WXgFFAWUaqQKKCJdQCGAlIRQ0na+P/JmdA2wCdmQ3eV3X9dcujNnz5wNgWfPnDNnLIZhGAIAAAAAwIl5FHYDAAAAAACwh84rAAAAAMDp0XkFAAAAADg9Oq8AAAAAAKdH5xUAAAAA4PTovAIAAAAAnB6dVwAAAACA06PzCgAAAABwel6F3QAAAPLq8uXLSktLc1h9Pj4+8vPzc1h9AADkBbmWO3ReAQAu5fLly6pYvqgST2Y6rM7w8HAdOnTILYMeAODcyLXco/MKAHApaWlpSjyZqUNbyysoMP93vySft6pig8NKS0tzu5AHADg/ci336LwCAFxSUKCHQ0IeAABnQK7ZR+cVAOCSMg2rMg3H1AMAQGEj1+yj8woAcElWGbIq/ynviDoAAMgvcs0+xqUBAAAAAE6PkVcAgEuyyipHTIxyTC0AAOQPuWYfnVcAgEvKNAxlGvmfGuWIOgAAyC9yzT6mDQMAAAAAnB4jrwAAl8TCFgAAd0Ku2UfnFQDgkqwylEnIAwDcBLlmH9OGAQAAAABOj5FXAIBLYnoVAMCdkGv2MfIKAAAAAHB6jLwCAFwSjxQAALgTcs0+Oq8AAJdk/d/miHoAAChs5Jp9TBsGAAAAADg9Rl4BAC4p00GPFHBEHQAA5Be5Zh+dVwCAS8o0sjZH1AMAQGEj1+xj2jAAAAAAwOkx8goAcEksbAEAcCfkmn10XgEALskqizJlcUg9AAAUNnLNPqYNAwAAAACcHiOvAACXZDWyNkfUAwBAYSPX7GPkFQAAAADg9Bh5BQC4pEwH3RvkiDoAAMgvcs0+Rl4BAC4pO+QdseXF2rVr1aVLF0VERMhisWjx4sU5yuzevVt33323goODFRAQoEaNGunIkSPm8cuXL2vgwIEqXry4ihYtqh49eujEiRM2dRw5ckSdOnVSkSJFVKpUKQ0bNkwZGRnX9bMCADg/cs0+Oq8AAOTBhQsXVLduXb311ltXPH7gwAE1a9ZMNWrU0OrVq7Vz506NGDFCfn5+ZpmnnnpKS5Ys0f/93/9pzZo1OnbsmLp3724ez8zMVKdOnZSWlqYffvhB8+fP17x58zRy5MgC/3wAgJuLK+WaxTAMN76lFwDgbpKTkxUcHKz1v0SoaGD+r8GmnLeqWe1jSkpKUlBQUJ7ea7FY9MUXX6hbt27mvt69e8vb21sffPDBFd+TlJSkkiVLauHChbr33nslSb/99puioqKUkJCg22+/XV9//bU6d+6sY8eOKSwsTJI0a9YsDR8+XKdOnZKPj8/1fVgAgNMh13Kfa4y8AgBckqOnVyUnJ9tsqampeW6T1WrVsmXLVK1aNUVHR6tUqVJq3LixzRSsrVu3Kj09XW3btjX31ahRQ+XKlVNCQoIkKSEhQXXq1DEDXpKio6OVnJysXbt2XedPDADgzMg1++i8AgAgKTIyUsHBweY2ceLEPNdx8uRJpaSk6JVXXlGHDh307bff6p577lH37t21Zs0aSVJiYqJ8fHwUEhJi896wsDAlJiaaZf4Z8NnHs48BAGCPO+Yaqw0DAFxSpjyU6YBrsJn/++/Ro0dtplf5+vrmuS6r1SpJ6tq1q5566ilJUr169fTDDz9o1qxZatmyZb7bCwBwT+SafYy8AgBckmFYZHXAZhhZ06uCgoJstusJ+RIlSsjLy0s1a9a02R8VFWWuyhgeHq60tDSdO3fOpsyJEycUHh5ulvn3Ko3Zr7PLAADcC7lmH51XAAAcxMfHR40aNdKePXts9u/du1fly5eXJDVo0EDe3t5atWqVeXzPnj06cuSImjRpIklq0qSJfv75Z508edIss2LFCgUFBeX4AgEAQEFxtlxj2jAAwCUV1sPcU1JStH//fvP1oUOHtH37doWGhqpcuXIaNmyYevXqpRYtWqh169Zavny5lixZotWrV0uSgoODFRcXp6efflqhoaEKCgrS4MGD1aRJE91+++2SpPbt26tmzZp64IEHNGnSJCUmJuqll17SwIEDr+vKOQDA+ZFr9vGoHACAS8l+pMDXOysqwAGPFLhw3qqOtxzK9SMFVq9erdatW+fYHxsbq3nz5kmS3n//fU2cOFF//PGHqlevrjFjxqhr165m2cuXL+uZZ57RRx99pNTUVEVHR+vtt9+2mTp1+PBhPfbYY1q9erUCAgIUGxurV155RV5eXHcGAHdCruU+1+i8AgBcSmGHPAAAjkSu5R6XbwEALskqi6wOWLrBKq7hAgAKH7lmHws2AQAAAACcHiOvAACXVFgLWwAAUBDINfvovAIAXFKm4aFMwwEPc2fpBwCAEyDX7GPaMAAAAADA6THyCgBwSVkLW+R/apQj6gAAIL/INfsYeQVQIObNmyeLxaLff//dbtkKFSqoX79+Bd4muBerPJTpgM0RKzsCAJBf5Jp97vvJABSIXbt2qW/fvipTpox8fX0VERGhmJgY7dq1q7CbBgDAVWVfVPXz89Off/6Z43irVq1Uu3btQmgZgNyi8wog1z7//HPVr19fq1at0kMPPaS3335bcXFx+v7771W/fn198cUXhd1E3ESyF7ZwxAbg5pGamqpXXnmlsJsB5ECu2cc9rwBy5cCBA3rggQdUqVIlrV27ViVLljSPPfHEE2revLkeeOAB7dy5U5UqVSrEluJmYXXQ1Ch3fpg7gJzq1aund999V88//7wiIiIKuzmAiVyzz3275QAcavLkybp48aJmz55t03GVpBIlSuidd97RhQsXNGnSpKvWYRiGxo8fr7Jly6pIkSJq3bo1040BADfUCy+8oMzMzFyNvn744Ydq0KCB/P39FRoaqt69e+vo0aPm8WnTpsnT01Pnzp0z902ZMkUWi0VPP/20uS8zM1OBgYEaPny4Qz8LcLOh8wogV5YsWaIKFSqoefPmVzzeokULVahQQcuWLbtqHSNHjtSIESNUt25dTZ48WZUqVVL79u114cKFgmo23FimYXHYBuDmUbFiRT344IN69913dezYsauWmzBhgh588EFVrVpVr7/+up588kmtWrVKLVq0MDurzZs3l9Vq1fr16833rVu3Th4eHlq3bp2576efflJKSopatGhRYJ8Lro9cs4/OKwC7kpKSdOzYMdWtW/ea5W655Rb98ccfOn/+fI5jp06d0qRJk9SpUyctXbpUAwcO1Jw5c9SvXz/99ddfBdV0AAByePHFF5WRkaFXX331iscPHz6sUaNGafz48fr444/12GOPaeTIkfr+++/1xx9/6O2335Yk1a1bV0FBQWZH1TAMrV+/Xj169DA7rNLfHdo77rjjxnxAwE3ReQVgV3ZnNDAw8Jrlso8nJyfnOLZy5UqlpaVp8ODBslj+viL45JNPOq6huKk44nEC2RuAm0ulSpX0wAMPaPbs2Tp+/HiO459//rmsVqt69uypv/76y9zCw8NVtWpVff/995IkDw8PNW3aVGvXrpUk7d69W6dPn9Zzzz0nwzCUkJAgKavzWrt2bYWEhNywzwjXQ67Z576fDIDDZHdKrzSi+k/X6uQePnxYklS1alWb/SVLllSxYsUc0UzcZKyGh8M2ADefl156SRkZGVe893Xfvn0yDENVq1ZVyZIlbbbdu3fr5MmTZtnmzZtr69atunTpktatW6fSpUurfv36qlu3rjkiu379+qvedgNkI9fsY7VhAHYFBwerdOnS2rlz5zXL7dy5U2XKlFFQUNANahkAANenUqVK6tu3r2bPnq3nnnvO5pjVapXFYtHXX38tT0/PHO8tWrSo+f/NmjVTenq6EhIStG7dOrOT2rx5c61bt06//fabTp06RecVcAA6rwBypXPnznr33Xe1fv16NWvWLMfxdev
2024-12-08 22:49:06 +04:00
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Below Average\", \"Above Average\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(optimized_metrics.index[index]) \n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-12-13 23:27:39 +04:00
"В желтом квадрате мы видим значение 1732, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"Below Average\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
2024-12-08 22:49:06 +04:00
"\n",
2024-12-13 23:27:39 +04:00
"В зеленом квадрате значение 1256 указывает на количество правильно классифицированных объектов, отнесенных к классу \"Above Average\". Это также является показателем высокой точности модели в определении объектов данного класса."
2024-12-08 22:49:06 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Загрузка данных и создание целевой переменной"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 25,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Среднее значение поля 'city_latitude': 38.70460772283549\n",
" summary city state \\\n",
"0 Viewed some red lights in the sky appearing to... Visalia CA \n",
"1 Look like 1 or 3 crafts from North traveling s... Cincinnati OH \n",
"2 seen dark rectangle moving slowly thru the sky... Tecopa CA \n",
"3 One red light moving switly west to east, beco... Knoxville TN \n",
"4 Bright, circular Fresnel-lens shaped light sev... Alexandria VA \n",
"\n",
" date_time shape duration \\\n",
"0 2021-12-15T21:45:00 light 2 minutes \n",
"1 2021-12-16T09:45:00 triangle 14 seconds \n",
"2 2021-12-10T00:00:00 rectangle Several minutes \n",
"3 2021-12-10T19:30:00 triangle 20-30 seconds \n",
"4 2021-12-07T08:00:00 circle NaN \n",
"\n",
" stats \\\n",
"0 Occurred : 12/15/2021 21:45 (Entered as : 12/... \n",
"1 Occurred : 12/16/2021 09:45 (Entered as : 12/... \n",
"2 Occurred : 12/10/2021 00:00 (Entered as : 12/... \n",
"3 Occurred : 12/10/2021 19:30 (Entered as : 12/... \n",
"4 Occurred : 12/7/2021 08:00 (Entered as : 12/0... \n",
"\n",
" report_link \\\n",
"0 http://www.nuforc.org/webreports/165/S165881.html \n",
"1 http://www.nuforc.org/webreports/165/S165888.html \n",
"2 http://www.nuforc.org/webreports/165/S165810.html \n",
"3 http://www.nuforc.org/webreports/165/S165825.html \n",
"4 http://www.nuforc.org/webreports/165/S165754.html \n",
"\n",
" text posted \\\n",
"0 Viewed some red lights in the sky appearing to... 2021-12-19T00:00:00 \n",
"1 Look like 1 or 3 crafts from North traveling s... 2021-12-19T00:00:00 \n",
"2 seen dark rectangle moving slowly thru the sky... 2021-12-19T00:00:00 \n",
"3 One red light moving switly west to east, beco... 2021-12-19T00:00:00 \n",
"4 Bright, circular Fresnel-lens shaped light sev... 2021-12-19T00:00:00 \n",
"\n",
" city_latitude city_longitude above_average_city_latitude \n",
"0 36.356650 -119.347937 0 \n",
"1 39.174503 -84.481363 1 \n",
"2 NaN NaN 0 \n",
"3 35.961561 -83.980115 0 \n",
"4 38.798958 -77.095133 1 \n",
"Статистическое описание DataFrame:\n",
" city_latitude city_longitude above_average_city_latitude\n",
"count 110136.000000 110136.000000 136940.000000\n",
"mean 38.704608 -95.185792 0.435928\n",
"std 5.752186 18.310088 0.495880\n",
"min -32.055500 -170.494000 0.000000\n",
"25% 34.238375 -113.901810 0.000000\n",
"50% 39.257500 -89.161450 0.000000\n",
"75% 42.317739 -80.363444 1.000000\n",
"max 64.845276 130.850580 1.000000\n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Загрузка данных\n",
2024-12-13 23:27:39 +04:00
"df = pd.read_csv(\"../../datasets/nuforc_reports.csv\")\n",
2024-12-08 22:49:06 +04:00
"\n",
"# Опция для настройки генерации случайных чисел \n",
"random_state = 42\n",
"\n",
"# Вычисление среднего значения поля \"city_latitude\"\n",
"average_city_latitude = df['city_latitude'].mean()\n",
"print(f\"Среднее значение поля 'city_latitude': {average_city_latitude}\")\n",
"\n",
"# Создание новой колонки, указывающей, выше или ниже среднего значение цены\n",
"df['above_average_city_latitude'] = (df['city_latitude'] > average_city_latitude).astype(int)\n",
"\n",
"# Вывод DataFrame с новой колонкой\n",
"print(df.head())\n",
"\n",
"# Примерный анализ данных\n",
"print(\"Статистическое описание DataFrame:\")\n",
"print(df.describe())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии\n",
"\n",
"Целевой признак -- above_average_city_latitude"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 26,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>summary</th>\n",
" <th>city</th>\n",
" <th>state</th>\n",
" <th>date_time</th>\n",
" <th>shape</th>\n",
" <th>duration</th>\n",
" <th>stats</th>\n",
" <th>report_link</th>\n",
" <th>text</th>\n",
" <th>posted</th>\n",
" <th>city_latitude</th>\n",
" <th>city_longitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>124678</th>\n",
" <td>4 lights moving in unison at a high rate of sp...</td>\n",
" <td>Mount Juliet</td>\n",
" <td>TN</td>\n",
" <td>2019-02-24T20:25:00</td>\n",
" <td>light</td>\n",
" <td>4 seconds</td>\n",
" <td>Occurred : 2/24/2019 20:25 (Entered as : 02/2...</td>\n",
" <td>http://www.nuforc.org/webreports/144/S144994.html</td>\n",
" <td>4 lights moving in unison at a high rate of sp...</td>\n",
" <td>2019-02-27T00:00:00</td>\n",
" <td>36.172148</td>\n",
" <td>-86.490748</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2212</th>\n",
" <td>Observed a long line of objects moving at very...</td>\n",
" <td>Bethany Beach</td>\n",
" <td>DE</td>\n",
" <td>2020-03-01T05:30:00</td>\n",
" <td>other</td>\n",
" <td>45 seconds</td>\n",
" <td>Occurred : 3/1/2020 05:30 (Entered as : 03/01...</td>\n",
" <td>http://www.nuforc.org/webreports/153/S153730.html</td>\n",
" <td>Observed a long line of objects moving at very...</td>\n",
" <td>2020-04-09T00:00:00</td>\n",
" <td>38.556200</td>\n",
" <td>-75.069200</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31960</th>\n",
" <td>My wife and myself were traveling to a friends...</td>\n",
" <td>Webb</td>\n",
" <td>MS</td>\n",
" <td>1975-12-31T19:30:00</td>\n",
" <td>disk</td>\n",
" <td>approx 5 mins</td>\n",
" <td>Occurred : 12/31/1975 19:30 (Entered as : 12/...</td>\n",
" <td>http://www.nuforc.org/webreports/037/S37355.html</td>\n",
" <td>My wife and myself were traveling to a friends...</td>\n",
" <td>2004-06-18T00:00:00</td>\n",
" <td>33.919100</td>\n",
" <td>-90.307300</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33954</th>\n",
" <td>I was sitting at my friends house, in my car a...</td>\n",
" <td>Ellensburg</td>\n",
" <td>WA</td>\n",
" <td>2004-03-17T08:45:00</td>\n",
" <td>diamond</td>\n",
" <td>10-20 minutes</td>\n",
" <td>Occurred : 3/17/2004 08:45 (Entered as : 03/1...</td>\n",
" <td>http://www.nuforc.org/webreports/035/S35995.html</td>\n",
" <td>I was sitting at my friends house, in my car a...</td>\n",
" <td>2004-04-09T00:00:00</td>\n",
" <td>46.979000</td>\n",
" <td>-120.470300</td>\n",
" </tr>\n",
" <tr>\n",
" <th>516</th>\n",
" <td>Observed two glimmering craft over Powell Rive...</td>\n",
" <td>Powell River</td>\n",
" <td>BC</td>\n",
" <td>2020-04-09T11:00:00</td>\n",
" <td>disk</td>\n",
" <td>3 minutes</td>\n",
" <td>Occurred : 4/9/2020 11:00 (Entered as : 04/09...</td>\n",
" <td>http://www.nuforc.org/webreports/155/S155845.html</td>\n",
" <td>Observed two glimmering craft over Powell Rive...</td>\n",
" <td>2020-06-25T00:00:00</td>\n",
" <td>50.016300</td>\n",
" <td>-124.322600</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>110268</th>\n",
" <td>Big black tranparent blib in sky.</td>\n",
" <td>Sedro Woolley</td>\n",
" <td>WA</td>\n",
" <td>2016-08-14T19:45:00</td>\n",
" <td>other</td>\n",
" <td>6 seconds</td>\n",
" <td>Occurred : 8/14/2016 19:45 (Entered as : 08/1...</td>\n",
" <td>http://www.nuforc.org/webreports/129/S129281.html</td>\n",
" <td>Big black tranparent blib in sky Walking at No...</td>\n",
" <td>2016-08-16T00:00:00</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>119879</th>\n",
" <td>Very fast bright light.</td>\n",
" <td>Brainerd</td>\n",
" <td>MN</td>\n",
" <td>2017-08-19T18:00:00</td>\n",
" <td>light</td>\n",
" <td>5 seconds</td>\n",
" <td>Occurred : 8/19/2017 18:00 (Entered as : 08/1...</td>\n",
" <td>http://www.nuforc.org/webreports/135/S135876.html</td>\n",
" <td>Very fast bright light. Observed a bright ligh...</td>\n",
" <td>2017-08-24T00:00:00</td>\n",
" <td>46.306700</td>\n",
" <td>-94.100800</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103694</th>\n",
" <td>never seen this before</td>\n",
" <td>Old Tappan</td>\n",
" <td>NJ</td>\n",
" <td>2015-03-12T05:49:00</td>\n",
" <td>triangle</td>\n",
" <td>20 seconds</td>\n",
" <td>Occurred : 3/12/2015 05:49 (Entered as : 0312...</td>\n",
" <td>http://www.nuforc.org/webreports/117/S117744.html</td>\n",
" <td>never seen this before. moved slow then all th...</td>\n",
" <td>2015-03-13T00:00:00</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131932</th>\n",
" <td>Back in the late 70s or early 80s my family wa...</td>\n",
" <td>Tacoma</td>\n",
" <td>WA</td>\n",
" <td>1980-06-01T09:00:00</td>\n",
" <td>sphere</td>\n",
" <td>10 seconds</td>\n",
" <td>Occurred : 6/1/1980 09:00 (Entered as : 06/01...</td>\n",
" <td>http://www.nuforc.org/webreports/164/S164812.html</td>\n",
" <td>Back in the late 70s or early 80s my family wa...</td>\n",
" <td>2021-10-19T00:00:00</td>\n",
" <td>47.212572</td>\n",
" <td>-122.459720</td>\n",
" </tr>\n",
" <tr>\n",
" <th>121958</th>\n",
" <td>Exiting Highway 61 onto I-280 and saw a light ...</td>\n",
" <td>Davenport</td>\n",
" <td>IA</td>\n",
" <td>2018-11-01T01:00:00</td>\n",
" <td>light</td>\n",
" <td>2 seconds</td>\n",
" <td>Occurred : 11/1/2018 01:00 (Entered as : 11/0...</td>\n",
" <td>http://www.nuforc.org/webreports/143/S143647.html</td>\n",
" <td>Exiting highway 61 on to Interstate 280 and sa...</td>\n",
" <td>2018-11-09T00:00:00</td>\n",
" <td>41.555164</td>\n",
" <td>-90.598760</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>109552 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" summary city \\\n",
"124678 4 lights moving in unison at a high rate of sp... Mount Juliet \n",
"2212 Observed a long line of objects moving at very... Bethany Beach \n",
"31960 My wife and myself were traveling to a friends... Webb \n",
"33954 I was sitting at my friends house, in my car a... Ellensburg \n",
"516 Observed two glimmering craft over Powell Rive... Powell River \n",
"... ... ... \n",
"110268 Big black tranparent blib in sky. Sedro Woolley \n",
"119879 Very fast bright light. Brainerd \n",
"103694 never seen this before Old Tappan \n",
"131932 Back in the late 70s or early 80s my family wa... Tacoma \n",
"121958 Exiting Highway 61 onto I-280 and saw a light ... Davenport \n",
"\n",
" state date_time shape duration \\\n",
"124678 TN 2019-02-24T20:25:00 light 4 seconds \n",
"2212 DE 2020-03-01T05:30:00 other 45 seconds \n",
"31960 MS 1975-12-31T19:30:00 disk approx 5 mins \n",
"33954 WA 2004-03-17T08:45:00 diamond 10-20 minutes \n",
"516 BC 2020-04-09T11:00:00 disk 3 minutes \n",
"... ... ... ... ... \n",
"110268 WA 2016-08-14T19:45:00 other 6 seconds \n",
"119879 MN 2017-08-19T18:00:00 light 5 seconds \n",
"103694 NJ 2015-03-12T05:49:00 triangle 20 seconds \n",
"131932 WA 1980-06-01T09:00:00 sphere 10 seconds \n",
"121958 IA 2018-11-01T01:00:00 light 2 seconds \n",
"\n",
" stats \\\n",
"124678 Occurred : 2/24/2019 20:25 (Entered as : 02/2... \n",
"2212 Occurred : 3/1/2020 05:30 (Entered as : 03/01... \n",
"31960 Occurred : 12/31/1975 19:30 (Entered as : 12/... \n",
"33954 Occurred : 3/17/2004 08:45 (Entered as : 03/1... \n",
"516 Occurred : 4/9/2020 11:00 (Entered as : 04/09... \n",
"... ... \n",
"110268 Occurred : 8/14/2016 19:45 (Entered as : 08/1... \n",
"119879 Occurred : 8/19/2017 18:00 (Entered as : 08/1... \n",
"103694 Occurred : 3/12/2015 05:49 (Entered as : 0312... \n",
"131932 Occurred : 6/1/1980 09:00 (Entered as : 06/01... \n",
"121958 Occurred : 11/1/2018 01:00 (Entered as : 11/0... \n",
"\n",
" report_link \\\n",
"124678 http://www.nuforc.org/webreports/144/S144994.html \n",
"2212 http://www.nuforc.org/webreports/153/S153730.html \n",
"31960 http://www.nuforc.org/webreports/037/S37355.html \n",
"33954 http://www.nuforc.org/webreports/035/S35995.html \n",
"516 http://www.nuforc.org/webreports/155/S155845.html \n",
"... ... \n",
"110268 http://www.nuforc.org/webreports/129/S129281.html \n",
"119879 http://www.nuforc.org/webreports/135/S135876.html \n",
"103694 http://www.nuforc.org/webreports/117/S117744.html \n",
"131932 http://www.nuforc.org/webreports/164/S164812.html \n",
"121958 http://www.nuforc.org/webreports/143/S143647.html \n",
"\n",
" text \\\n",
"124678 4 lights moving in unison at a high rate of sp... \n",
"2212 Observed a long line of objects moving at very... \n",
"31960 My wife and myself were traveling to a friends... \n",
"33954 I was sitting at my friends house, in my car a... \n",
"516 Observed two glimmering craft over Powell Rive... \n",
"... ... \n",
"110268 Big black tranparent blib in sky Walking at No... \n",
"119879 Very fast bright light. Observed a bright ligh... \n",
"103694 never seen this before. moved slow then all th... \n",
"131932 Back in the late 70s or early 80s my family wa... \n",
"121958 Exiting highway 61 on to Interstate 280 and sa... \n",
"\n",
" posted city_latitude city_longitude \n",
"124678 2019-02-27T00:00:00 36.172148 -86.490748 \n",
"2212 2020-04-09T00:00:00 38.556200 -75.069200 \n",
"31960 2004-06-18T00:00:00 33.919100 -90.307300 \n",
"33954 2004-04-09T00:00:00 46.979000 -120.470300 \n",
"516 2020-06-25T00:00:00 50.016300 -124.322600 \n",
"... ... ... ... \n",
"110268 2016-08-16T00:00:00 NaN NaN \n",
"119879 2017-08-24T00:00:00 46.306700 -94.100800 \n",
"103694 2015-03-13T00:00:00 NaN NaN \n",
"131932 2021-10-19T00:00:00 47.212572 -122.459720 \n",
"121958 2018-11-09T00:00:00 41.555164 -90.598760 \n",
"\n",
"[109552 rows x 12 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_average_city_latitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>124678</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2212</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31960</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33954</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>516</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>110268</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>119879</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>103694</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131932</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>121958</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>109552 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_average_city_latitude\n",
"124678 0\n",
"2212 0\n",
"31960 0\n",
"33954 1\n",
"516 1\n",
"... ...\n",
"110268 0\n",
"119879 1\n",
"103694 0\n",
"131932 1\n",
"121958 1\n",
"\n",
"[109552 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>summary</th>\n",
" <th>city</th>\n",
" <th>state</th>\n",
" <th>date_time</th>\n",
" <th>shape</th>\n",
" <th>duration</th>\n",
" <th>stats</th>\n",
" <th>report_link</th>\n",
" <th>text</th>\n",
" <th>posted</th>\n",
" <th>city_latitude</th>\n",
" <th>city_longitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>108340</th>\n",
" <td>Bright light with irregular flight pattern and...</td>\n",
" <td>Pittsboro</td>\n",
" <td>NC</td>\n",
" <td>2015-12-08T21:00:00</td>\n",
" <td>light</td>\n",
" <td>5 minutes</td>\n",
" <td>Occurred : 12/8/2015 21:00 (Entered as : 12/0...</td>\n",
" <td>http://www.nuforc.org/webreports/124/S124610.html</td>\n",
" <td>Bright light with irregular flight pattern and...</td>\n",
" <td>2015-12-17T00:00:00</td>\n",
" <td>35.751900</td>\n",
" <td>-79.224800</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29071</th>\n",
" <td>3 lights flash less than seconds apart in a ro...</td>\n",
" <td>San Luis Obispo</td>\n",
" <td>CA</td>\n",
" <td>2004-03-16T22:10:00</td>\n",
" <td>flash</td>\n",
" <td>3 seconds</td>\n",
" <td>Occurred : 3/16/2004 22:10 (Entered as : 03/1...</td>\n",
" <td>http://www.nuforc.org/webreports/035/S35640.html</td>\n",
" <td>3 lights flash less than seconds apart in a ro...</td>\n",
" <td>2004-03-17T00:00:00</td>\n",
" <td>35.262867</td>\n",
" <td>-120.624789</td>\n",
" </tr>\n",
" <tr>\n",
" <th>89462</th>\n",
" <td>A wonderful aircraft, but spooky.</td>\n",
" <td>Cornwall</td>\n",
" <td>ON</td>\n",
" <td>2013-07-06T00:30:00</td>\n",
" <td>triangle</td>\n",
" <td>20 seconds</td>\n",
" <td>Occurred : 7/6/2013 00:30 (Entered as : 76201...</td>\n",
" <td>http://www.nuforc.org/webreports/099/S99896.html</td>\n",
" <td>A wonderful aircraft , but spooky Saw aircraft...</td>\n",
" <td>2013-07-14T00:00:00</td>\n",
" <td>45.056209</td>\n",
" <td>-74.710143</td>\n",
" </tr>\n",
" <tr>\n",
" <th>124422</th>\n",
" <td>I was outside facing east, speaking with a co-...</td>\n",
" <td>Virginia Beach</td>\n",
" <td>VA</td>\n",
" <td>2019-01-09T06:35:00</td>\n",
" <td>fireball</td>\n",
" <td>5-7 seconds</td>\n",
" <td>Occurred : 1/9/2019 06:35 (Entered as : 1-9-1...</td>\n",
" <td>http://www.nuforc.org/webreports/144/S144555.html</td>\n",
" <td>I was outside facing east, speaking with a co-...</td>\n",
" <td>2019-01-24T00:00:00</td>\n",
" <td>36.837301</td>\n",
" <td>-76.061948</td>\n",
" </tr>\n",
" <tr>\n",
" <th>126342</th>\n",
" <td>Bright circular light in the sky - once we sta...</td>\n",
" <td>Fruitland</td>\n",
" <td>UT</td>\n",
" <td>2019-05-11T23:40:00</td>\n",
" <td>circle</td>\n",
" <td>2 minutes</td>\n",
" <td>Occurred : 5/11/2019 23:40 (Entered as : 5/11...</td>\n",
" <td>http://www.nuforc.org/webreports/146/S146570.html</td>\n",
" <td>Bright circular light in the sky - once we sta...</td>\n",
" <td>2019-06-07T00:00:00</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48714</th>\n",
" <td>Was this white cylinder with fins a weather ba...</td>\n",
" <td>Cape Coral</td>\n",
" <td>FL</td>\n",
" <td>2007-04-29T15:30:00</td>\n",
" <td>cylinder</td>\n",
" <td>30 mins</td>\n",
" <td>Occurred : 4/29/2007 15:30 (Entered as : 04/2...</td>\n",
" <td>http://www.nuforc.org/webreports/056/S56462.html</td>\n",
" <td>Was this white cylinder with fins a weather ba...</td>\n",
" <td>2007-06-12T00:00:00</td>\n",
" <td>26.616422</td>\n",
" <td>-81.970066</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25512</th>\n",
" <td>I was driving down a rural road near Haymarket...</td>\n",
" <td>Haymarket</td>\n",
" <td>VA</td>\n",
" <td>2003-03-11T22:00:00</td>\n",
" <td>triangle</td>\n",
" <td>3 minutes</td>\n",
" <td>Occurred : 3/11/2003 22:00 (Entered as : 3-11...</td>\n",
" <td>http://www.nuforc.org/webreports/028/S28125.html</td>\n",
" <td>I was driving down a rural road near Haymarket...</td>\n",
" <td>2003-03-21T00:00:00</td>\n",
" <td>38.869400</td>\n",
" <td>-77.637300</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96155</th>\n",
" <td>Five orange lights in the sky over Essex area,...</td>\n",
" <td>Essex Junction</td>\n",
" <td>VT</td>\n",
" <td>2014-05-25T21:50:00</td>\n",
" <td>light</td>\n",
" <td>10 minutes</td>\n",
" <td>Occurred : 5/25/2014 21:50 (Entered as : 05/2...</td>\n",
" <td>http://www.nuforc.org/webreports/109/S109718.html</td>\n",
" <td>Five orange lights in the sky over Essex area,...</td>\n",
" <td>2014-06-04T00:00:00</td>\n",
" <td>44.532199</td>\n",
" <td>-73.058631</td>\n",
" </tr>\n",
" <tr>\n",
" <th>82188</th>\n",
" <td>It was slowly moving over our pastures going a...</td>\n",
" <td>New Washington</td>\n",
" <td>OH</td>\n",
" <td>2013-01-24T21:00:00</td>\n",
" <td>circle</td>\n",
" <td>5 minutes</td>\n",
" <td>Occurred : 1/24/2013 21:00 (Entered as : 1/24...</td>\n",
" <td>http://www.nuforc.org/webreports/096/S96095.html</td>\n",
" <td>It was slowly moving over our pastures going a...</td>\n",
" <td>2013-02-04T00:00:00</td>\n",
" <td>40.945000</td>\n",
" <td>-82.861800</td>\n",
" </tr>\n",
" <tr>\n",
" <th>124520</th>\n",
" <td>2 circles in the sky changed colors several ti...</td>\n",
" <td>South Amboy</td>\n",
" <td>NJ</td>\n",
" <td>2019-02-02T00:28:00</td>\n",
" <td>circle</td>\n",
" <td>1 minute</td>\n",
" <td>Occurred : 2/2/2019 00:28 (Entered as : 02/2/...</td>\n",
" <td>http://www.nuforc.org/webreports/144/S144752.html</td>\n",
" <td>2 circles in the sky changed colors several ti...</td>\n",
" <td>2019-02-07T00:00:00</td>\n",
" <td>40.477900</td>\n",
" <td>-74.290700</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>27388 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" summary city \\\n",
"108340 Bright light with irregular flight pattern and... Pittsboro \n",
"29071 3 lights flash less than seconds apart in a ro... San Luis Obispo \n",
"89462 A wonderful aircraft, but spooky. Cornwall \n",
"124422 I was outside facing east, speaking with a co-... Virginia Beach \n",
"126342 Bright circular light in the sky - once we sta... Fruitland \n",
"... ... ... \n",
"48714 Was this white cylinder with fins a weather ba... Cape Coral \n",
"25512 I was driving down a rural road near Haymarket... Haymarket \n",
"96155 Five orange lights in the sky over Essex area,... Essex Junction \n",
"82188 It was slowly moving over our pastures going a... New Washington \n",
"124520 2 circles in the sky changed colors several ti... South Amboy \n",
"\n",
" state date_time shape duration \\\n",
"108340 NC 2015-12-08T21:00:00 light 5 minutes \n",
"29071 CA 2004-03-16T22:10:00 flash 3 seconds \n",
"89462 ON 2013-07-06T00:30:00 triangle 20 seconds \n",
"124422 VA 2019-01-09T06:35:00 fireball 5-7 seconds \n",
"126342 UT 2019-05-11T23:40:00 circle 2 minutes \n",
"... ... ... ... ... \n",
"48714 FL 2007-04-29T15:30:00 cylinder 30 mins \n",
"25512 VA 2003-03-11T22:00:00 triangle 3 minutes \n",
"96155 VT 2014-05-25T21:50:00 light 10 minutes \n",
"82188 OH 2013-01-24T21:00:00 circle 5 minutes \n",
"124520 NJ 2019-02-02T00:28:00 circle 1 minute \n",
"\n",
" stats \\\n",
"108340 Occurred : 12/8/2015 21:00 (Entered as : 12/0... \n",
"29071 Occurred : 3/16/2004 22:10 (Entered as : 03/1... \n",
"89462 Occurred : 7/6/2013 00:30 (Entered as : 76201... \n",
"124422 Occurred : 1/9/2019 06:35 (Entered as : 1-9-1... \n",
"126342 Occurred : 5/11/2019 23:40 (Entered as : 5/11... \n",
"... ... \n",
"48714 Occurred : 4/29/2007 15:30 (Entered as : 04/2... \n",
"25512 Occurred : 3/11/2003 22:00 (Entered as : 3-11... \n",
"96155 Occurred : 5/25/2014 21:50 (Entered as : 05/2... \n",
"82188 Occurred : 1/24/2013 21:00 (Entered as : 1/24... \n",
"124520 Occurred : 2/2/2019 00:28 (Entered as : 02/2/... \n",
"\n",
" report_link \\\n",
"108340 http://www.nuforc.org/webreports/124/S124610.html \n",
"29071 http://www.nuforc.org/webreports/035/S35640.html \n",
"89462 http://www.nuforc.org/webreports/099/S99896.html \n",
"124422 http://www.nuforc.org/webreports/144/S144555.html \n",
"126342 http://www.nuforc.org/webreports/146/S146570.html \n",
"... ... \n",
"48714 http://www.nuforc.org/webreports/056/S56462.html \n",
"25512 http://www.nuforc.org/webreports/028/S28125.html \n",
"96155 http://www.nuforc.org/webreports/109/S109718.html \n",
"82188 http://www.nuforc.org/webreports/096/S96095.html \n",
"124520 http://www.nuforc.org/webreports/144/S144752.html \n",
"\n",
" text \\\n",
"108340 Bright light with irregular flight pattern and... \n",
"29071 3 lights flash less than seconds apart in a ro... \n",
"89462 A wonderful aircraft , but spooky Saw aircraft... \n",
"124422 I was outside facing east, speaking with a co-... \n",
"126342 Bright circular light in the sky - once we sta... \n",
"... ... \n",
"48714 Was this white cylinder with fins a weather ba... \n",
"25512 I was driving down a rural road near Haymarket... \n",
"96155 Five orange lights in the sky over Essex area,... \n",
"82188 It was slowly moving over our pastures going a... \n",
"124520 2 circles in the sky changed colors several ti... \n",
"\n",
" posted city_latitude city_longitude \n",
"108340 2015-12-17T00:00:00 35.751900 -79.224800 \n",
"29071 2004-03-17T00:00:00 35.262867 -120.624789 \n",
"89462 2013-07-14T00:00:00 45.056209 -74.710143 \n",
"124422 2019-01-24T00:00:00 36.837301 -76.061948 \n",
"126342 2019-06-07T00:00:00 NaN NaN \n",
"... ... ... ... \n",
"48714 2007-06-12T00:00:00 26.616422 -81.970066 \n",
"25512 2003-03-21T00:00:00 38.869400 -77.637300 \n",
"96155 2014-06-04T00:00:00 44.532199 -73.058631 \n",
"82188 2013-02-04T00:00:00 40.945000 -82.861800 \n",
"124520 2019-02-07T00:00:00 40.477900 -74.290700 \n",
"\n",
"[27388 rows x 12 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_average_city_latitude</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>108340</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29071</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>89462</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>124422</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>126342</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48714</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25512</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96155</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>82188</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>124520</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>27388 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_average_city_latitude\n",
"108340 0\n",
"29071 0\n",
"89462 1\n",
"124422 0\n",
"126342 0\n",
"... ...\n",
"48714 0\n",
"25512 1\n",
"96155 1\n",
"82188 1\n",
"124520 1\n",
"\n",
"[27388 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" target_colname: str = \"above_average_city_latitude\", \n",
" frac_train: float = 0.8,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" # Проверка наличия целевого признака\n",
" if target_colname not in df_input.columns:\n",
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
" \n",
" # Разделяем данные на признаки и целевую переменную\n",
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
" y = df_input[[target_colname]] # Целевая переменная\n",
"\n",
" # Разделяем данные на обучающую и тестовую выборки\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" \n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"# Применение функции для разделения данных\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
" df, \n",
" target_colname=\"above_average_city_latitude\", \n",
" frac_train=0.8, \n",
" random_state=42 \n",
")\n",
"\n",
"# Для отображения результатов\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование конвейера для решения задачи регрессии"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 27,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" city_latitude city_longitude shape_changing shape_chevron \\\n",
"0 -0.475704 -1.527173 0.0 0.0 \n",
"1 0.070048 0.574031 0.0 0.0 \n",
"2 0.086123 0.291990 0.0 0.0 \n",
"3 -0.552224 0.604238 0.0 0.0 \n",
"4 -0.002686 1.019156 0.0 0.0 \n",
"... ... ... ... ... \n",
"136935 1.426621 -1.516027 0.0 0.0 \n",
"136936 -2.044798 0.757950 0.0 0.0 \n",
"136937 0.180564 -0.660057 0.0 0.0 \n",
"136938 -0.719495 -1.453787 0.0 0.0 \n",
"136939 0.785101 1.328590 0.0 0.0 \n",
"\n",
" shape_cigar shape_circle shape_cone shape_crescent shape_cross \\\n",
"0 0.0 0.0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 0.0 0.0 \n",
"4 0.0 1.0 0.0 0.0 0.0 \n",
"... ... ... ... ... ... \n",
"136935 0.0 0.0 0.0 0.0 0.0 \n",
"136936 0.0 0.0 0.0 0.0 0.0 \n",
"136937 0.0 0.0 0.0 0.0 0.0 \n",
"136938 0.0 0.0 0.0 0.0 0.0 \n",
"136939 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" shape_cylinder ... shape_light shape_other shape_oval \\\n",
"0 0.0 ... 1.0 0.0 0.0 \n",
"1 0.0 ... 0.0 0.0 0.0 \n",
"2 0.0 ... 0.0 0.0 0.0 \n",
"3 0.0 ... 0.0 0.0 0.0 \n",
"4 0.0 ... 0.0 0.0 0.0 \n",
"... ... ... ... ... ... \n",
"136935 0.0 ... 0.0 0.0 0.0 \n",
"136936 1.0 ... 0.0 0.0 0.0 \n",
"136937 0.0 ... 1.0 0.0 0.0 \n",
"136938 0.0 ... 0.0 0.0 0.0 \n",
"136939 0.0 ... 0.0 0.0 1.0 \n",
"\n",
" shape_pyramid shape_rectangle shape_round shape_sphere \\\n",
"0 0.0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 \n",
"2 0.0 1.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 0.0 \n",
"... ... ... ... ... \n",
"136935 0.0 0.0 0.0 0.0 \n",
"136936 0.0 0.0 0.0 0.0 \n",
"136937 0.0 0.0 0.0 0.0 \n",
"136938 0.0 0.0 0.0 0.0 \n",
"136939 0.0 0.0 0.0 0.0 \n",
"\n",
" shape_teardrop shape_triangle shape_unknown \n",
"0 0.0 0.0 0.0 \n",
"1 0.0 1.0 0.0 \n",
"2 0.0 0.0 0.0 \n",
"3 0.0 1.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"... ... ... ... \n",
"136935 0.0 1.0 0.0 \n",
"136936 0.0 0.0 0.0 \n",
"136937 0.0 0.0 0.0 \n",
"136938 0.0 0.0 0.0 \n",
"136939 0.0 0.0 0.0 \n",
"\n",
"[136940 rows x 30 columns]\n",
"(136940, 30)\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.ensemble import RandomForestRegressor \n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"import pandas as pd\n",
"\n",
"class JioMartFeatures(BaseEstimator, TransformerMixin): \n",
" def __init__(self):\n",
" pass\n",
"\n",
" def fit(self, X, y=None):\n",
" return self\n",
"\n",
" def transform(self, X, y=None):\n",
" if 'state' in X.columns:\n",
" X[\"city_latitude_per_state\"] = X[\"city_latitude\"] / X[\"state\"].nunique()\n",
" return X\n",
"\n",
" def get_feature_names_out(self, features_in):\n",
" return np.append(features_in, [\"city_latitude_per_state\"], axis=0) \n",
"\n",
"# Определите признаки для вашей задачи\n",
"columns_to_drop = [\"date_time\", \"posted\", \"city\", \"state\", \"summary\", \"stats\", \"report_link\", \"duration\", \"text\"] # Столбцы, которые можно удалить\n",
"num_columns = [\"city_latitude\", \"city_longitude\"] # Числовые столбцы\n",
"cat_columns = [\"shape\"] # Категориальные столбцы\n",
"\n",
"# Преобразование числовых признаков\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Преобразование категориальных признаков\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Формирование конвейера\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\" \n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Окончательный конвейер\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" (\"custom_features\", JioMartFeatures()), # Добавляем custom_features\n",
" ]\n",
")\n",
"\n",
"# Загрузка данных\n",
2024-12-13 23:27:39 +04:00
"df = pd.read_csv(\"../../datasets/nuforc_reports.csv\")\n",
2024-12-08 22:49:06 +04:00
"\n",
"# Создаем целевой признак\n",
"average_city_latitude = df['city_latitude'].mean()\n",
"df['above_average_city_latitude'] = (df['city_latitude'] > average_city_latitude).astype(int)\n",
"\n",
"# Подготовка данных\n",
"X = df.drop('above_average_city_latitude', axis=1)\n",
"y = df['above_average_city_latitude'].values.ravel()\n",
"\n",
"# Проверка наличия столбцов перед применением конвейера\n",
"required_columns = set(num_columns + cat_columns + columns_to_drop)\n",
"missing_columns = required_columns - set(X.columns)\n",
"if missing_columns:\n",
" raise KeyError(f\"Missing columns: {missing_columns}\")\n",
"\n",
"# Применение конвейера\n",
"X_processed = pipeline_end.fit_transform(X)\n",
"\n",
"# Вывод\n",
"print(X_processed)\n",
"print(X_processed.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование набора моделей для регрессии"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 28,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"Random Forest: Mean Score = 0.9789167519338394, Standard Deviation = 0.015961406592291463\n",
2024-12-08 22:49:06 +04:00
"Linear Regression: Mean Score = 0.5039253856797983, Standard Deviation = 0.030322793232352978\n",
2024-12-13 23:27:39 +04:00
"Gradient Boosting: Mean Score = 0.9901053617704161, Standard Deviation = 0.008763228065314608\n",
2024-12-08 22:49:06 +04:00
"Support Vector Regression: Mean Score = 0.8080621690604891, Standard Deviation = 0.04395269414319326\n"
]
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"from sklearn.svm import SVR\n",
"\n",
"def train_multiple_models(X, y, models, cv=3):\n",
" results = {}\n",
" for model_name, model in models.items():\n",
" # Создаем конвейер для каждой модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" (\"model\", model) # Используем текущую модель\n",
" ]\n",
" )\n",
" \n",
" # Обучаем модель и вычисляем кросс-валидацию\n",
" scores = cross_val_score(model_pipeline, X, y, cv=cv, n_jobs=-1) # Используем все ядра процессора\n",
" results[model_name] = {\n",
" \"mean_score\": scores.mean(),\n",
" \"std_dev\": scores.std()\n",
" }\n",
" \n",
" return results\n",
"\n",
"# Определение моделей\n",
"models = {\n",
" \"Random Forest\": RandomForestRegressor(n_estimators=10), # Уменьшаем количество деревьев\n",
" \"Linear Regression\": LinearRegression(),\n",
" \"Gradient Boosting\": GradientBoostingRegressor(),\n",
" \"Support Vector Regression\": SVR()\n",
"}\n",
"\n",
"# Используем подвыборку данных\n",
"sample_size = 1000 # Уменьшаем количество данных для обучения\n",
"X_train_sample = X_train.sample(n=sample_size, random_state=42)\n",
"y_train_sample = y_train.loc[X_train_sample.index] # Используем loc для индексации Series\n",
"\n",
"# Обучение моделей и вывод результатов\n",
"results = train_multiple_models(X_train_sample, y_train_sample, models, cv=3) # Уменьшаем количество фолдов\n",
"\n",
"# Вывод результатов\n",
"for model_name, scores in results.items():\n",
" print(f\"{model_name}: Mean Score = {scores['mean_score']}, Standard Deviation = {scores['std_dev']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Модель: Random Forest\n",
2024-12-13 23:27:39 +04:00
"- **Mean Score**: 0.9789167519338394\n",
"- **Standard Deviation**: 0.015961406592291463\n",
2024-12-08 22:49:06 +04:00
"**Описание**:\n",
"- Random Forest показала очень высокое среднее значение, близкое к 1, что указывает на е е высокую точность в предсказании. Стандартное отклонение также относительно низкое, что говорит о стабильности модели.\n",
"\n",
"#### Модель: Linear Regression\n",
2024-12-13 23:27:39 +04:00
"- **Mean Score**: 0.5039253856797983\n",
"- **Standard Deviation**: 0.030322793232352978\n",
2024-12-08 22:49:06 +04:00
"**Описание**:\n",
2024-12-13 23:27:39 +04:00
"- Линейная регрессия показала очень низкое среднее значение, что указывает на е е неэффективность в данной задаче. Стандартное отклонение также очень высокое, что говорит о нестабильности модели.\n",
2024-12-08 22:49:06 +04:00
"\n",
"#### Модель: Gradient Boosting\n",
2024-12-13 23:27:39 +04:00
"- **Mean Score**: 0.9901053617704161\n",
"- **Standard Deviation**: 0.008763228065314608\n",
2024-12-08 22:49:06 +04:00
"**Описание**:\n",
"- Gradient Boosting показала практически идеальное среднее значение, близкое к 1, что указывает на е е высокую точность в предсказании. Стандартное отклонение относительно низкое, что говорит о стабильности модели.\n",
"\n",
"#### Модель: Support Vector Regression\n",
2024-12-13 23:27:39 +04:00
"- **Mean Score**: 0.8080621690604891\n",
"- **Standard Deviation**: 0.04395269414319326\n",
2024-12-08 22:49:06 +04:00
"**Описание**:\n",
"- Support Vector Regression показала среднее значение около 0.64, что указывает на е е умеренную точность в предсказании. Стандартное отклонение относительно низкое, что говорит о стабильности модели, но она все же уступает Random Forest и Gradient Boosting.\n",
"\n",
"\n",
"1. **Random Forest и Gradient Boosting** демонстрируют высокую точность и стабильность, что делает их наиболее подходящими моделями для данной задачи регрессии.\n",
"2. **Linear Regression** неэффективна и нестабильна, что указывает на необходимость е е замены на более подходящую модель.\n",
"3. **Support Vector Regression** показывает умеренную точность и стабильность, но уступает Random Forest и Gradient Boosting в эффективности."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение моделей на обучающем наборе данных и оценка на тестовом для регрессии"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 29,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE (train): 0.04693661457572659\n",
"MSE (test): 0.04651672265225646\n",
"MAE (train): 0.04693661457572659\n",
"MAE (test): 0.04651672265225646\n",
"R2 (train): 0.8090856146775022\n",
"R2 (test): 0.810957714373292\n",
"STD (train): 0.21150311767890387\n",
"STD (test): 0.21060132280199356\n",
"----------------------------------------\n",
"Model: ridge\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE (train): 0.04438987877902731\n",
"MSE (test): 0.04465459325251935\n",
"MAE (train): 0.04438987877902731\n",
"MAE (test): 0.04465459325251935\n",
"R2 (train): 0.8194444465532269\n",
"R2 (test): 0.8185253411919435\n",
"STD (train): 0.20595974713766416\n",
"STD (test): 0.2065443307233859\n",
"----------------------------------------\n",
"Model: decision_tree\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: knn\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE (train): 0.0024737111143566526\n",
"MSE (test): 0.0035416970936176426\n",
"MAE (train): 0.0024737111143566526\n",
"MAE (test): 0.0035416970936176426\n",
"R2 (train): 0.9899381955615719\n",
"R2 (test): 0.9856066705606039\n",
"STD (train): 0.04973611399311279\n",
"STD (test): 0.05950515838317305\n",
"----------------------------------------\n",
"Model: naive_bayes\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE (train): 0.12056375054768512\n",
"MSE (test): 0.12202424419453775\n",
"MAE (train): 0.12056375054768512\n",
"MAE (test): 0.12202424419453775\n",
"R2 (train): 0.5096077010230355\n",
"R2 (test): 0.5040978661189495\n",
"STD (train): 0.32835683006894384\n",
"STD (test): 0.33012146600139913\n",
"----------------------------------------\n",
"Model: gradient_boosting\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: random_forest\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"MSE (train): 0.0001551774499780926\n",
"MSE (test): 0.00025558638819921136\n",
"MAE (train): 0.0001551774499780926\n",
"MAE (test): 0.00025558638819921136\n",
"R2 (train): 0.9993688166957444\n",
"R2 (test): 0.9989613061229302\n",
"STD (train): 0.012456057559962977\n",
"STD (test): 0.015985026236993754\n",
2024-12-08 22:49:06 +04:00
"----------------------------------------\n",
"Model: mlp\n",
"MSE (train): 0.0206659851029648\n",
"MSE (test): 0.01971666423251059\n",
"MAE (train): 0.0206659851029648\n",
"MAE (test): 0.01971666423251059\n",
"R2 (train): 0.9159412352450146\n",
"R2 (test): 0.9198721866260421\n",
"STD (train): 0.1430221622813633\n",
"STD (test): 0.13976464450173487\n",
"----------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\pipeline.py:62: FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"# Проверка наличия необходимых переменных\n",
"if 'class_models' not in locals():\n",
" raise ValueError(\"class_models is not defined\")\n",
"if 'X_train' not in locals() or 'X_test' not in locals() or 'y_train' not in locals() or 'y_test' not in locals():\n",
" raise ValueError(\"Train/test data is not defined\")\n",
"\n",
"# Преобразуем y_train и y_test в одномерные массивы\n",
"y_train = np.ravel(y_train) \n",
"y_test = np.ravel(y_test) \n",
"\n",
"# Инициализация списка для хранения результатов\n",
"results = []\n",
"\n",
"# Проход по моделям и оценка их качества\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" \n",
" # Извлечение модели из словаря\n",
" model = class_models[model_name][\"model\"]\n",
" \n",
" # Создание пайплайна\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" \n",
" # Обучение модели\n",
" model_pipeline.fit(X_train, y_train)\n",
"\n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
"\n",
" # Сохранение пайплайна и предсказаний\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" # Вычисление метрик для регрессии\n",
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
"\n",
" # Дополнительные метрики\n",
" class_models[model_name][\"STD_train\"] = np.std(y_train - y_train_predict)\n",
" class_models[model_name][\"STD_test\"] = np.std(y_test - y_test_predict)\n",
"\n",
" # Вывод результатов для текущей модели\n",
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
" print(f\"STD (train): {class_models[model_name]['STD_train']}\")\n",
" print(f\"STD (test): {class_models[model_name]['STD_test']}\")\n",
" print(\"-\" * 40) # Разделитель для разных моделей"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования обученной модели (конвейера регрессии) для предсказания"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 30,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"Model: RandomForest\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE (train): 0.003994806410348714\n",
"MSE (test): 0.08567146325791795\n",
"MAE (train): 0.002250011462828781\n",
"MAE (test): 0.009170197805056415\n",
"R2 (train): 0.9998844679299405\n",
"R2 (test): 0.9972453050448624\n",
2024-12-08 22:49:06 +04:00
"----------------------------------------\n",
2024-12-13 23:27:39 +04:00
"Прогнозируемая цена: -15.206353276691255\n"
2024-12-08 22:49:06 +04:00
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestRegressor \n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"# 1. Загрузка данных\n",
2024-12-13 23:27:39 +04:00
"data = pd.read_csv(\"../../datasets/nuforc_reports.csv\") \n",
"data = data.head(15000)\n",
2024-12-08 22:49:06 +04:00
"\n",
"# 2. Подготовка данных для прогноза\n",
"average_city_latitude = data['city_latitude'].mean()\n",
"data['above_average_city_latitude'] = (data['city_latitude'] > average_city_latitude).astype(int) \n",
"\n",
"# Удаляем строки с пропущенными значениями в столбце 'city_latitude'\n",
"data = data.dropna(subset=['city_latitude'])\n",
"\n",
"# Предикторы и целевая переменная\n",
"X = data.drop('above_average_city_latitude', axis=1) # Удаляем только 'above_average_city_latitude'\n",
"y = data['city_latitude']\n",
"\n",
"# 3. Инициализация модели и пайплайна\n",
"class_models = {\n",
" \"RandomForest\": {\n",
" \"model\": RandomForestRegressor(n_estimators=100, random_state=42),\n",
" }\n",
"}\n",
"\n",
"# Предобработка признаков\n",
"num_columns = ['city_latitude']\n",
"cat_columns = ['state', 'city']\n",
"\n",
"# Проверка наличия столбцов перед предобработкой\n",
"required_columns = set(num_columns + cat_columns)\n",
"missing_columns = required_columns - set(X.columns)\n",
"if missing_columns:\n",
" raise KeyError(f\"Missing columns: {missing_columns}\")\n",
"\n",
"# Преобразование числовых признаков\n",
"num_transformer = Pipeline(steps=[\n",
" ('imputer', SimpleImputer(strategy='median')),\n",
" ('scaler', StandardScaler())\n",
"])\n",
"\n",
"# Преобразование категориальных признаков\n",
"cat_transformer = Pipeline(steps=[\n",
" ('imputer', SimpleImputer(strategy='constant', fill_value='unknown')),\n",
" ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop=\"first\"))\n",
"])\n",
"\n",
"# Создание конвейера предобработки\n",
"preprocessor = ColumnTransformer(\n",
" transformers=[\n",
" ('num', num_transformer, num_columns),\n",
" ('cat', cat_transformer, cat_columns)\n",
" ])\n",
"\n",
"# Создание конвейера модели\n",
"pipeline_end = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" # ('model', model) # Модель добавляется в цикле\n",
"])\n",
"\n",
"results = []\n",
"\n",
"# 4. Обучение модели и оценка\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
"\n",
" model = class_models[model_name][\"model\"]\n",
" model_pipeline = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('model', model)\n",
" ])\n",
"\n",
" # Разделение данных\n",
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
" # Обучение модели\n",
" model_pipeline.fit(X_train, y_train)\n",
"\n",
" # Предсказание\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
"\n",
" # Сохранение результатов\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" # Вычисление метрик\n",
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
"\n",
" # Вывод результатов\n",
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
" print(\"-\" * 40)\n",
"\n",
"# Прогнозирование цены для нового товара\n",
"new_item_data = pd.DataFrame({\n",
" 'state': ['Electronics'],\n",
" 'city': ['Smartphones'], \n",
" 'city_latitude': [0] # Добавляем столбец 'city_latitude' с нулевым значением\n",
"})\n",
"\n",
"predicted_city_latitude = model_pipeline.predict(new_item_data)\n",
"print(f\"Прогнозируемая цена: {predicted_city_latitude[0]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 31,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
2024-12-13 23:27:39 +04:00
"Лучшие параметры: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 200}\n",
"Лучший результат (MSE): 1.278829563088853\n"
2024-12-08 22:49:06 +04:00
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"\n",
"# Удаление строк с пропущенными значениями (если необходимо)\n",
"df = df.dropna()\n",
"\n",
"# Создание целевой переменной (city_latitude)\n",
"target = df['city_latitude']\n",
"\n",
"# Удаление целевой переменной из исходных данных\n",
"features = df.drop(columns=['city_latitude'])\n",
"\n",
"# Удаление столбцов, которые не будут использоваться (например, href и items)\n",
"features = features.drop(columns=[\"date_time\", \"posted\", \"city\", \"state\", \"summary\", \"stats\", \"report_link\", \"duration\", \"text\"])\n",
"\n",
"# Определение столбцов для обработки\n",
"num_columns = features.select_dtypes(include=['number']).columns\n",
"cat_columns = features.select_dtypes(include=['object']).columns\n",
"\n",
"# Препроцессинг числовых столбцов\n",
"num_imputer = SimpleImputer(strategy=\"median\") # Используем медиану для заполнения пропущенных значений в числовых столбцах\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Препроцессинг категориальных столбцов\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\") # Используем 'unknown' для заполнения пропущенных значений в категориальных столбцах\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Объединение препроцессинга\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Создание финального пайплайна\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")\n",
"\n",
"# Разделение данных на обучающую и тестовую выборки\n",
"X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)\n",
"\n",
"# Применение пайплайна к данным\n",
"X_train_processed = pipeline_end.fit_transform(X_train)\n",
"X_test_processed = pipeline_end.transform(X_test)\n",
"\n",
"# 2. Создание и настройка модели случайного леса\n",
"model = RandomForestRegressor()\n",
"\n",
"# Установка параметров для поиска по сетке\n",
"param_grid = {\n",
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"# 3. Подбор гиперпараметров с помощью Grid Search\n",
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_train_processed, y_train)\n",
"\n",
"# 4. Результаты подбора гиперпараметров\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_) # Меняем знак, так как берем отрицательное значение среднеквадратичной ошибки"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами и сравнение новых и старых данных"
]
},
{
"cell_type": "code",
2024-12-13 23:27:39 +04:00
"execution_count": 33,
2024-12-08 22:49:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" summary city state \\\n",
"0 Viewed some red lights in the sky appearing to... Visalia CA \n",
"1 Look like 1 or 3 crafts from North traveling s... Cincinnati OH \n",
"3 One red light moving switly west to east, beco... Knoxville TN \n",
"5 I'm familiar with all the fakery and UFO sight... Fullerton CA \n",
"6 I was driving up lakes mead towards the lake a... Las Vegas NV \n",
"\n",
" date_time shape duration \\\n",
"0 2021-12-15T21:45:00 light 2 minutes \n",
"1 2021-12-16T09:45:00 triangle 14 seconds \n",
"3 2021-12-10T19:30:00 triangle 20-30 seconds \n",
"5 2020-07-07T23:00:00 unknown 2 minutes \n",
"6 2020-04-23T03:00:00 oval 10 minutes \n",
"\n",
" stats \\\n",
"0 Occurred : 12/15/2021 21:45 (Entered as : 12/... \n",
"1 Occurred : 12/16/2021 09:45 (Entered as : 12/... \n",
"3 Occurred : 12/10/2021 19:30 (Entered as : 12/... \n",
"5 Occurred : 7/7/2020 23:00 (Entered as : 07/07... \n",
"6 Occurred : 4/23/2020 03:00 (Entered as : 4/23... \n",
"\n",
" report_link \\\n",
"0 http://www.nuforc.org/webreports/165/S165881.html \n",
"1 http://www.nuforc.org/webreports/165/S165888.html \n",
"3 http://www.nuforc.org/webreports/165/S165825.html \n",
"5 http://www.nuforc.org/webreports/157/S157444.html \n",
"6 http://www.nuforc.org/webreports/155/S155608.html \n",
"\n",
" text posted \\\n",
"0 Viewed some red lights in the sky appearing to... 2021-12-19T00:00:00 \n",
"1 Look like 1 or 3 crafts from North traveling s... 2021-12-19T00:00:00 \n",
"3 One red light moving switly west to east, beco... 2021-12-19T00:00:00 \n",
"5 I'm familiar with all the fakery and UFO sight... 2020-07-09T00:00:00 \n",
"6 I was driving up lakes mead towards the lake a... 2020-05-01T00:00:00 \n",
"\n",
" city_latitude city_longitude \n",
"0 36.356650 -119.347937 \n",
"1 39.174503 -84.481363 \n",
"3 35.961561 -83.980115 \n",
"5 33.877422 -117.924978 \n",
"6 36.141246 -115.186592 \n",
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"C:\\Users\\Danil\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\sklearn\\preprocessing\\_encoders.py:246: UserWarning: Found unknown categories in columns [0, 1, 2, 3, 4] during transform. These unknown categories will be encoded as all zeros\n",
2024-12-08 22:49:06 +04:00
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-12-13 23:27:39 +04:00
"Старые параметры: {'max_depth': 10, 'min_samples_split': 2, 'n_estimators': 200}\n",
"Лучший результат (MSE) на старых параметрах: 0.5709161764408962\n",
2024-12-08 22:49:06 +04:00
"\n",
"Новые параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 200}\n",
2024-12-13 23:27:39 +04:00
"Лучший результат (MSE) на новых параметрах: 4.346034257469189\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 0.1459878244079514\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 0.3820835306682969\n"
2024-12-08 22:49:06 +04:00
]
},
{
"data": {
2024-12-13 23:27:39 +04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAHWCAYAAACi1sL/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/GU6VOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1gURx/A8e/ROwiooKKg2LFhw44VbLHGGhVL1MSeGDXG3rDXxBYTS+xij8kbKwo2EAW7IkUsWAAB6W3fP04unBTBAAc6n+e5R253duZ3e8WdnSaTJElCEARBEARBEARBAEBN1QEIgiAIgiAIgiAUJaKSJAiCIAiCIAiCkIGoJAmCIAiCIAiCIGQgKkmCIAiCIAiCIAgZiEqSIAiCIAiCIAhCBqKSJAiCIAiCIAiCkIGoJAmCIAiCIAiCIGQgKkmCIAiCIAiCIAgZiEqSIAiCIAiCIAhCBqKSJAiCIAiCIGTrzz//xNfXV/H8yJEj3LlzR3UBCUIhEJUkQSjCAgICGDVqFBUrVkRHRwcjIyOaNWvGmjVriI+PV3V4giAIwmfg1q1bTJgwAX9/f65cucLo0aN5+/atqsMShAIlkyRJUnUQgiBkduLECb788ku0tbUZPHgwdnZ2JCUl4enpycGDB3FxcWHz5s2qDlMQBEH4xL1+/ZqmTZvy6NEjAHr27MnBgwdVHJUgFCxRSRKEIigoKIjatWtTrlw5zp49i6WlpdL+R48eceLECSZMmKCiCAVBEITPSWJiIrdv30ZPT4/q1aurOhxBKHCiu50gFEFLly4lJiaG3377LVMFCcDW1lapgiSTyRg7diy7du2iatWq6OjoUL9+fS5cuKB03OPHj/n222+pWrUqurq6mJmZ8eWXXxIcHKyUbtu2bchkMsVDT0+PWrVqsWXLFqV0Li4uGBgYZIrPzc0NmUyGu7u70varV6/i7OyMsbExenp6tGrViosXLyqlmTNnDjKZjLCwMKXt165dQyaTsW3bNqXyra2tldI9efIEXV1dZDJZptf1999/06JFC/T19TE0NKRz58656leffj4uXLjAqFGjMDMzw8jIiMGDB/PmzZtM6XNTzs2bN3FxcVF0pbSwsGDYsGGEh4dnGYO1tbXSe5L+yHiOra2t6dKlS46vJTg4GJlMxvLlyzPts7Ozw9HRUfHc3d0dmUyGm5tbtvm9/x7Mnj0bNTU1zpw5o5Ru5MiRaGlp4efnl2N8MpmMOXPmKG1btmwZMplMKbacjs/ukTHOjOdh1apVVKhQAV1dXVq1asXt27cz5Xv//n169+6NqakpOjo6NGjQgGPHjmUZg4uLS5blu7i4ZEr7999/06pVKwwNDTEyMqJhw4bs3r1bsd/R0THT6164cCFqampK6Tw8PPjyyy8pX7482traWFlZMWnSpEzdcufMmUONGjUwMDDAyMgIBwcHjhw5opQmt3nl5fvv6OiInZ1dprTLly/P9F390Oc4/XOZnv+9e/fQ1dVl8ODBSuk8PT1RV1dn6tSp2eYFuTsneYn/6NGjdO7cmTJlyqCtrU2lSpWYP38+qampSsdm9VlP/635mN+uvL4f73+uvL29FZ/VrOLU1tamfv36VK9ePU/fSUEorjRUHYAgCJkdP36cihUr0rRp01wfc/78efbt28f48ePR1tZm/fr1ODs74+XlpfjP3dvbm0uXLtGvXz/KlStHcHAwGzZswNHRkbt376Knp6eU56pVqzA3Nyc6Oprff/+dr7/+Gmtra9q1a5fn13T27Fk6duxI/fr1FRfSW7dupU2bNnh4eNCoUaM855mVWbNmkZCQkGn7H3/8wZAhQ3BycmLJkiXExcWxYcMGmjdvzo0bNzJVtrIyduxYTExMmDNnDg8ePGDDhg08fvxYcdGWl3JOnTpFYGAgQ4cOxcLCgjt37rB582bu3LnDlStXMl2oALRo0YKRI0cC8gvDRYsWffyJKiAzZszg+PHjDB8+nFu3bmFoaMg///zDr7/+yvz586lTp06e8ouMjMTV1TVPx7Rv3z7TBfOKFSuyrNDu2LGDt2/fMmbMGBISElizZg1t2rTh1q1blC5dGoA7d+7QrFkzypYty7Rp09DX12f//v10796dgwcP0qNHj0z5amtrK91UGDFiRKY027ZtY9iwYdSsWZMff/wRExMTbty4wf/+9z8GDBiQ5WvbunUrM2bMYMWKFUppDhw4QFxcHN988w1mZmZ4eXmxbt06nj59yoEDBxTpYmNj6dGjB9bW1sTHx7Nt2zZ69erF5cuXFd/B3OZVVFSvXp358+fzww8/0Lt3b7744gtiY2NxcXGhWrVqzJs3L8fjc3NO8mLbtm0YGBjw3XffYWBgwNmzZ5k1axbR0dEsW7Ysz/nlx29XbnyoMpnuY76TglAsSYIgFClRUVESIHXr1i3XxwASIF27dk2x7fHjx5KOjo7Uo0cPxba4uLhMx16+fFkCpB07dii2bd26VQKkoKAgxbaHDx9KgLR06VLFtiFDhkj6+vqZ8jxw4IAESOfOnZMkSZLS0tKkypUrS05OTlJaWppSPDY2NlL79u0V22bPni0B0uvXr5Xy9Pb2lgBp69atSuVXqFBB8fz27duSmpqa1LFjR6X43759K5mYmEhff/21Up4vXryQjI2NM21/X/r5qF+/vpSUlKTYvnTpUgmQjh49mudysnov9uzZIwHShQsXMu0rW7asNHToUMXzc+fOKZ1jSZKkChUqSJ07d87xtQQFBUmAtGzZskz7atasKbVq1SpTGQcOHMg2v/ffA0mSpFu3bklaWlrSiBEjpDdv3khly5aVGjRoICUnJ+cYmyTJP8uzZ89WPJ8yZYpUqlQpqX79+kqx5XT8mDFjMm3v3LmzUpzp50FXV1d6+vSpYvvVq1clQJo0aZJiW9u2baVatWpJCQkJim1paWlS06ZNpcqVK2cqa8CAAZKBgYHSNn19fWnIkCGK55GRkZKhoaHUuHFjKT4+Xiltxu9Iq1atFK/7xIkTkoaGhvT9999nKjOrz5Orq6skk8mkx48fZ9qX7tWrVxIgLV++PM955fb7n/46atasmSntsmXLMv3WfOhznNVnPzU1VWrevLlUunRpKSwsTBozZoykoaEheXt7Z5tPdrI6J3mJP6vzN2rUKElPT0/pMySTyaRZs2YppXv/tzcvvyl5fT8yfp/++usvCZCcnZ2l9y8N/+t3UhCKK9HdThCKmOjoaAAMDQ3zdFyTJk2oX7++4nn58uXp1q0b//zzj6Kbh66urmJ/cnIy4eHh2NraYmJiwvXr1zPl+ebNG8LCwggMDGTVqlWoq6vTqlWrTOnCwsKUHu/PeuTr64u/vz8DBgwgPDxckS42Npa2bdty4cIF0tLSlI6JiIhQyjMqKuqD5+DHH3/E3t6eL7/8Umn7qVOniIyMpH///kp5qqur07hxY86dO/fBvEHeZUxTU1Px/JtvvkFDQ4O//vorz+VkfC8SEhIICwvDwcEBIMv3IikpCW1t7Q/GmJycTFhYGOHh4aSkpGSbLi4uLtP79n53oHRv374lLCyMyMjID5YP8m57c+fOZcuWLTg5OREWFsb27dvR0Mhb54Vnz56xbt06Zs6cmWU3ovzQvXt3ypYtq3jeqFEjGjdurHhPIyIiOHv2LH369FGch/Tz6+TkhL+/P8+ePVPKMyEhAR0dnRzLPXXqFG/fvmXatGmZ0mbViujl5UWfPn3o1atXlq0RGT9PsbGxhIWF0bRpUyRJ4saNG0pp0z8jAQEBLF68GDU1NZo1a/ZRecGHv//pUlNTM6WNi4vLMm1uP8fp1NTU2LZtGzExMXTs2JH169fz448/0qBBgw8em7G87M5JXuLPeP7SPzMtWrQgLi6O+/fvK/aVKlWKp0+f5hjXx/x25fb9SCdJEj/++CO9evWicePGOaYtjO+kIBQVorudIBQxRkZGAHmeXrVy5cqZtlWpUoW4uDhev36NhYUF8fHxuLq6snXrVp49e4aUYd6WrCoh9vb2ir+1tbX5+eefM3U/iY2NpWTJkjnG5u/vD8CQIUOyTRMVFUWJEiUUz6tWrZpjnu/z9PTk+PH
2024-12-08 22:49:06 +04:00
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Загрузка датасета\n",
2024-12-13 23:27:39 +04:00
"df = pd.read_csv(\"../../datasets/nuforc_reports.csv\").head(100).dropna()\n",
2024-12-08 22:49:06 +04:00
"\n",
"# Вывод первых строк для проверки структуры\n",
"print(df.head())\n",
"\n",
"# Целевая переменная\n",
"target = df['city_latitude']\n",
"\n",
"# Удаление целевой переменной из признаков\n",
"features = df.drop(columns=['summary', 'stats', 'report_link', 'posted', \"duration\"])\n",
"\n",
"# Определение столбцов для обработки\n",
"num_columns = features.select_dtypes(include=['number']).columns\n",
"cat_columns = features.select_dtypes(include=['object']).columns\n",
"\n",
"# Препроцессинг числовых столбцов\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline([\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
"])\n",
"\n",
"# Препроцессинг категориальных столбцов\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline([\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
"])\n",
"\n",
"# Объединение препроцессинга\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Создание финального пайплайна\n",
"pipeline_end = Pipeline([\n",
" (\"features_preprocessing\", features_preprocessing),\n",
"])\n",
"\n",
"# Разделение данных на обучающую и тестовую выборки\n",
"X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)\n",
"\n",
"# Применение пайплайна к данным\n",
"X_train_processed = pipeline_end.fit_transform(X_train)\n",
"X_test_processed = pipeline_end.transform(X_test)\n",
"\n",
"# 1. Настройка параметров для старых значений\n",
"old_param_grid = {\n",
" 'n_estimators': [50, 100, 200],\n",
" 'max_depth': [None, 10, 20, 30],\n",
" 'min_samples_split': [2, 5, 10]\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью Grid Search для старых параметров\n",
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(),\n",
" param_grid=old_param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"old_grid_search.fit(X_train_processed, y_train)\n",
"\n",
"# Результаты подбора для старых параметров\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_\n",
"\n",
"# 2. Настройка параметров для новых значений\n",
"new_param_grid = {\n",
" 'n_estimators': [200],\n",
" 'max_depth': [10],\n",
" 'min_samples_split': [10]\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью Grid Search для новых параметров\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(),\n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"new_grid_search.fit(X_train_processed, y_train)\n",
"\n",
"# Результаты подбора для новых параметров\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_\n",
"\n",
"# 5. Обучение модели с лучшими параметрами для новых значений\n",
"model_best = RandomForestRegressor(**new_best_params)\n",
"model_best.fit(X_train_processed, y_train)\n",
"\n",
"# Прогнозирование на тестовой выборке\n",
"y_pred = model_best.predict(X_test_processed)\n",
"\n",
"# Оценка производительности модели\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"# Вывод результатов\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nН о вые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)\n",
"\n",
"# Обучение модели с лучшими параметрами для старых значений\n",
"model_old = RandomForestRegressor(**old_best_params)\n",
"model_old.fit(X_train_processed, y_train)\n",
"\n",
"# Прогнозирование на тестовой выборке для старых параметров\n",
"y_pred_old = model_old.predict(X_test_processed)\n",
"\n",
"# Визуализация ошибок\n",
"plt.figure(figsize=(10, 5))\n",
"plt.plot(y_test.values, label='Реальные значения', marker='o', linestyle='-', color='black')\n",
"plt.plot(y_pred_old, label='Предсказанные значения (старые параметры)', marker='x', linestyle='--', color='blue')\n",
"plt.plot(y_pred, label='Предсказанные значения (новые параметры)', marker='s', linestyle='--', color='orange')\n",
"plt.xlabel('Объекты')\n",
"plt.ylabel('Цена')\n",
"plt.title('Сравнение реальных и предсказанных значений')\n",
"plt.legend()\n",
"plt.show()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-12-13 23:27:39 +04:00
"version": "3.12.8"
2024-12-08 22:49:06 +04:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}