4217 lines
442 KiB
Plaintext
4217 lines
442 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Лабораторная работа №4\n",
|
|||
|
"\n",
|
|||
|
"*Вариант задания:* "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Выбор бизнес-целей \n",
|
|||
|
"Для датасета наблюдений НЛО (NUFORC, файл nuforc_reports.csv) предлагаются две бизнес-цели:\n",
|
|||
|
"\n",
|
|||
|
"### Задача классификации:\n",
|
|||
|
"*Цель*: Классифицировать наблюдения на две категории -- \"выше среднего\" и \"ниже среднего\" по широте города (city_latitude) -- на основе характеристик наблюдения.\n",
|
|||
|
"\n",
|
|||
|
"*Применение*: Полезно для определения целевой аудитории для разных типов товаров, создания маркетинговых кампаний и анализа рыночных сегментов.\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"### Задача регрессии:\n",
|
|||
|
"*Цель*: Предсказать широту появления (city_latitude) на основе других характеристик.\n",
|
|||
|
"\n",
|
|||
|
"*Применение*: "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Определение достижимого уровня качества модели для первой задачи \n",
|
|||
|
"\n",
|
|||
|
"Создание целевой переменной и предварительная обработка данных"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 28,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Index(['summary', 'city', 'state', 'date_time', 'shape', 'duration', 'stats',\n",
|
|||
|
" 'report_link', 'text', 'posted', 'city_latitude', 'city_longitude'],\n",
|
|||
|
" dtype='object')\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
|||
|
"from sklearn import linear_model, tree, neighbors, naive_bayes, ensemble, neural_network\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import warnings\n",
|
|||
|
"#warnings.filterwarnings(\"ignore\", state=UserWarning)\n",
|
|||
|
"df = pd.read_csv(\"nuforc_reports.csv\")\n",
|
|||
|
"df = df.head(1000)\n",
|
|||
|
"print(df.columns)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 29,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Среднее значение поля city_latitude: 39.215819681793704\n",
|
|||
|
" summary city state \\\n",
|
|||
|
"0 Viewed some red lights in the sky appearing to... Visalia CA \n",
|
|||
|
"1 Look like 1 or 3 crafts from North traveling s... Cincinnati OH \n",
|
|||
|
"2 seen dark rectangle moving slowly thru the sky... Tecopa CA \n",
|
|||
|
"3 One red light moving switly west to east, beco... Knoxville TN \n",
|
|||
|
"4 Bright, circular Fresnel-lens shaped light sev... Alexandria VA \n",
|
|||
|
"\n",
|
|||
|
" date_time shape duration \\\n",
|
|||
|
"0 2021-12-15T21:45:00 light 2 minutes \n",
|
|||
|
"1 2021-12-16T09:45:00 triangle 14 seconds \n",
|
|||
|
"2 2021-12-10T00:00:00 rectangle Several minutes \n",
|
|||
|
"3 2021-12-10T19:30:00 triangle 20-30 seconds \n",
|
|||
|
"4 2021-12-07T08:00:00 circle NaN \n",
|
|||
|
"\n",
|
|||
|
" stats \\\n",
|
|||
|
"0 Occurred : 12/15/2021 21:45 (Entered as : 12/... \n",
|
|||
|
"1 Occurred : 12/16/2021 09:45 (Entered as : 12/... \n",
|
|||
|
"2 Occurred : 12/10/2021 00:00 (Entered as : 12/... \n",
|
|||
|
"3 Occurred : 12/10/2021 19:30 (Entered as : 12/... \n",
|
|||
|
"4 Occurred : 12/7/2021 08:00 (Entered as : 12/0... \n",
|
|||
|
"\n",
|
|||
|
" report_link \\\n",
|
|||
|
"0 http://www.nuforc.org/webreports/165/S165881.html \n",
|
|||
|
"1 http://www.nuforc.org/webreports/165/S165888.html \n",
|
|||
|
"2 http://www.nuforc.org/webreports/165/S165810.html \n",
|
|||
|
"3 http://www.nuforc.org/webreports/165/S165825.html \n",
|
|||
|
"4 http://www.nuforc.org/webreports/165/S165754.html \n",
|
|||
|
"\n",
|
|||
|
" text posted \\\n",
|
|||
|
"0 Viewed some red lights in the sky appearing to... 2021-12-19T00:00:00 \n",
|
|||
|
"1 Look like 1 or 3 crafts from North traveling s... 2021-12-19T00:00:00 \n",
|
|||
|
"2 seen dark rectangle moving slowly thru the sky... 2021-12-19T00:00:00 \n",
|
|||
|
"3 One red light moving switly west to east, beco... 2021-12-19T00:00:00 \n",
|
|||
|
"4 Bright, circular Fresnel-lens shaped light sev... 2021-12-19T00:00:00 \n",
|
|||
|
"\n",
|
|||
|
" city_latitude city_longitude above_average_city_latitude \n",
|
|||
|
"0 36.356650 -119.347937 0 \n",
|
|||
|
"1 39.174503 -84.481363 0 \n",
|
|||
|
"2 NaN NaN 0 \n",
|
|||
|
"3 35.961561 -83.980115 0 \n",
|
|||
|
"4 38.798958 -77.095133 0 \n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Установим параметры для вывода\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"\n",
|
|||
|
"# Рассчитываем среднее значение широты города (city_latitude)\n",
|
|||
|
"average_city_latitude = df['city_latitude'].mean()\n",
|
|||
|
"print(f\"Среднее значение поля city_latitude: {average_city_latitude}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создаем новую переменную, указывающую, превышает ли широта города среднее значение\n",
|
|||
|
"df['above_average_city_latitude'] = (df['city_latitude'] > average_city_latitude).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Выводим первые строки измененной таблицы для проверки\n",
|
|||
|
"print(df.head())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Разделение набора данных на обучающую и тестовую выборки (80/20) для задачи классификации\n",
|
|||
|
"\n",
|
|||
|
"Целевой признак -- above_average_city_latitude"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 30,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"X_train shape: (800, 7)\n",
|
|||
|
"y_train shape: (800,)\n",
|
|||
|
"X_test shape: (200, 7)\n",
|
|||
|
"y_test shape: (200,)\n",
|
|||
|
"X_train:\n",
|
|||
|
" city state date_time shape \\\n",
|
|||
|
"214 Clayton NM 1997-06-18T02:30:00 disk \n",
|
|||
|
"831 Cedar Rapids IA 2020-04-18T22:00:00 other \n",
|
|||
|
"35 Lufkin TX 1993-02-09T19:00:00 delta \n",
|
|||
|
"431 Leesburg VA 2020-03-30T21:00:00 oval \n",
|
|||
|
"726 Roseville MN 2020-04-16T21:20:00 light \n",
|
|||
|
"\n",
|
|||
|
" text city_latitude \\\n",
|
|||
|
"214 have endured a low pitched motor hum for 24 ye... 36.401600 \n",
|
|||
|
"831 31 satellite like objects flying straight line... 41.977695 \n",
|
|||
|
"35 SUMMARY: Family traveling home along a rural ... 31.315223 \n",
|
|||
|
"431 We were on a walk and saw a vertical string of... 39.122452 \n",
|
|||
|
"726 Lights traveling at high speeds across the sky... 45.006100 \n",
|
|||
|
"\n",
|
|||
|
" city_longitude \n",
|
|||
|
"214 -103.355000 \n",
|
|||
|
"831 -91.675865 \n",
|
|||
|
"35 -94.746566 \n",
|
|||
|
"431 -77.563847 \n",
|
|||
|
"726 -93.156600 \n",
|
|||
|
"y_train:\n",
|
|||
|
" 214 0\n",
|
|||
|
"831 1\n",
|
|||
|
"35 0\n",
|
|||
|
"431 0\n",
|
|||
|
"726 1\n",
|
|||
|
"Name: above_average_city_latitude, dtype: int64\n",
|
|||
|
"X_test:\n",
|
|||
|
" city state date_time shape \\\n",
|
|||
|
"541 Frackville PA 2020-04-11T00:58:00 cigar \n",
|
|||
|
"797 Seminole OK 2020-04-17T22:45:00 light \n",
|
|||
|
"887 Sanford FL 2020-04-20T23:34:00 NaN \n",
|
|||
|
"516 Powell River BC 2020-04-09T11:00:00 disk \n",
|
|||
|
"410 Dayton OH 2020-03-29T00:00:00 circle \n",
|
|||
|
"\n",
|
|||
|
" text city_latitude \\\n",
|
|||
|
"541 This was the best encounter. Now this is the 2... 40.785200 \n",
|
|||
|
"797 My husband and I were driving last night and I... 35.243167 \n",
|
|||
|
"887 MADAR Node 91 28.814930 \n",
|
|||
|
"516 Observed two glimmering craft over Powell Rive... 50.016300 \n",
|
|||
|
"410 3/29/20 and 4/5/20 in my back yard I didn’t ... 39.735409 \n",
|
|||
|
"\n",
|
|||
|
" city_longitude \n",
|
|||
|
"541 -76.223000 \n",
|
|||
|
"797 -96.636440 \n",
|
|||
|
"887 -81.339465 \n",
|
|||
|
"516 -124.322600 \n",
|
|||
|
"410 -84.167628 \n",
|
|||
|
"y_test:\n",
|
|||
|
" 541 1\n",
|
|||
|
"797 0\n",
|
|||
|
"887 0\n",
|
|||
|
"516 1\n",
|
|||
|
"410 1\n",
|
|||
|
"Name: above_average_city_latitude, dtype: int64\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Разделение набора данных на обучающую и тестовую выборки (80/20)\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(\n",
|
|||
|
" df.drop(columns=['above_average_city_latitude', 'summary', 'stats', 'report_link', 'posted', \"duration\"]), # Исключаем столбец 'items'\n",
|
|||
|
" df['above_average_city_latitude'],\n",
|
|||
|
" stratify=df['above_average_city_latitude'],\n",
|
|||
|
" test_size=0.20,\n",
|
|||
|
" random_state=random_state\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Вывод размеров выборок\n",
|
|||
|
"print(\"X_train shape:\", X_train.shape)\n",
|
|||
|
"print(\"y_train shape:\", y_train.shape)\n",
|
|||
|
"print(\"X_test shape:\", X_test.shape)\n",
|
|||
|
"print(\"y_test shape:\", y_test.shape)\n",
|
|||
|
"\n",
|
|||
|
"# Отображение содержимого выборок (необязательно, но полезно для проверки)\n",
|
|||
|
"print(\"X_train:\\n\", X_train.head())\n",
|
|||
|
"print(\"y_train:\\n\", y_train.head())\n",
|
|||
|
"print(\"X_test:\\n\", X_test.head())\n",
|
|||
|
"print(\"y_test:\\n\", y_test.head())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование конвейера для классификации данных\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
|
|||
|
"\n",
|
|||
|
"features_preprocessing -- трансформер для предобработки признаков\n",
|
|||
|
"\n",
|
|||
|
"drop_columns -- трансформер для удаления колонок\n",
|
|||
|
"\n",
|
|||
|
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 31,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Определение столбцов для обработки\n",
|
|||
|
"columns_to_drop = [\"date_time\", \"posted\", \"city\", \"state\", \"summary\", \"stats\", \"report_link\", \"duration\", \"text\"] # Столбцы, которые можно удалить\n",
|
|||
|
"# ,\n",
|
|||
|
"num_columns = [\"city_latitude\", \"city_longitude\"] # Числовые столбцы\n",
|
|||
|
"cat_columns = [\"shape\"] # Категориальные столбцы\n",
|
|||
|
"\n",
|
|||
|
"# Проверка наличия столбцов перед удалением\n",
|
|||
|
"columns_to_drop = [col for col in columns_to_drop if col in X_train.columns]\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг числовых столбцов\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг категориальных столбцов\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Объединение препроцессинга\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Удаление ненужных столбцов\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание финального пайплайна\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Обучение пайплайна на обучающих данных\n",
|
|||
|
"pipeline_end.fit(X_train)\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование тестовых данных с использованием обученного пайплайна\n",
|
|||
|
"X_test_transformed = pipeline_end.transform(X_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Демонстрация работы конвейера__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 32,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" city_latitude city_longitude shape_chevron shape_cigar shape_circle \\\n",
|
|||
|
"214 -0.553732 -0.447864 0.0 0.0 0.0 \n",
|
|||
|
"831 0.530905 0.260652 0.0 0.0 0.0 \n",
|
|||
|
"35 -1.543111 0.074367 0.0 0.0 0.0 \n",
|
|||
|
"431 -0.024484 1.116759 0.0 0.0 0.0 \n",
|
|||
|
"726 1.119977 0.170823 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" shape_cone shape_cross shape_cylinder shape_delta shape_diamond ... \\\n",
|
|||
|
"214 0.0 0.0 0.0 0.0 0.0 ... \n",
|
|||
|
"831 0.0 0.0 0.0 0.0 0.0 ... \n",
|
|||
|
"35 0.0 0.0 0.0 1.0 0.0 ... \n",
|
|||
|
"431 0.0 0.0 0.0 0.0 0.0 ... \n",
|
|||
|
"726 0.0 0.0 0.0 0.0 0.0 ... \n",
|
|||
|
"\n",
|
|||
|
" shape_flash shape_formation shape_light shape_other shape_oval \\\n",
|
|||
|
"214 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"831 0.0 0.0 0.0 1.0 0.0 \n",
|
|||
|
"35 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"431 0.0 0.0 0.0 0.0 1.0 \n",
|
|||
|
"726 0.0 0.0 1.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" shape_rectangle shape_sphere shape_teardrop shape_triangle \\\n",
|
|||
|
"214 0.0 0.0 0.0 0.0 \n",
|
|||
|
"831 0.0 0.0 0.0 0.0 \n",
|
|||
|
"35 0.0 0.0 0.0 0.0 \n",
|
|||
|
"431 0.0 0.0 0.0 0.0 \n",
|
|||
|
"726 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" shape_unknown \n",
|
|||
|
"214 0.0 \n",
|
|||
|
"831 0.0 \n",
|
|||
|
"35 0.0 \n",
|
|||
|
"431 0.0 \n",
|
|||
|
"726 0.0 \n",
|
|||
|
"\n",
|
|||
|
"[5 rows x 23 columns]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Вывод первых строк обработанных данных\n",
|
|||
|
"print(preprocessed_df.head())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование набора моделей для классификации\n",
|
|||
|
"\n",
|
|||
|
"logistic -- логистическая регрессия\n",
|
|||
|
"\n",
|
|||
|
"ridge -- логистическая регрессия с L2-регуляризацией и сбалансированными весами классов\n",
|
|||
|
"\n",
|
|||
|
"decision_tree -- дерево решений\n",
|
|||
|
"\n",
|
|||
|
"knn -- k-ближайших соседей\n",
|
|||
|
"\n",
|
|||
|
"naive_bayes -- наивный Байесовский классификатор\n",
|
|||
|
"\n",
|
|||
|
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"random_forest -- метод случайного леса (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"mlp -- многослойный персептрон (нейронная сеть)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 33,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class_models = {\n",
|
|||
|
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
|
|||
|
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
|
|||
|
" \"decision_tree\": {\n",
|
|||
|
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=42)\n",
|
|||
|
" },\n",
|
|||
|
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
|
|||
|
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
|
|||
|
" \"gradient_boosting\": {\n",
|
|||
|
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
|
|||
|
" },\n",
|
|||
|
" \"random_forest\": {\n",
|
|||
|
" \"model\": ensemble.RandomForestClassifier(\n",
|
|||
|
" max_depth=11, class_weight=\"balanced\", random_state=42\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"mlp\": {\n",
|
|||
|
" \"model\": neural_network.MLPClassifier(\n",
|
|||
|
" hidden_layer_sizes=(7,),\n",
|
|||
|
" max_iter=500,\n",
|
|||
|
" early_stopping=True,\n",
|
|||
|
" random_state=42,\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
"}"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Обучение моделей и оценка их качества"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 34,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"Model: knn\n",
|
|||
|
"Model: naive_bayes"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"Model: gradient_boosting\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
|||
|
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
|||
|
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
"\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
|||
|
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
|||
|
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" # Оценка метрик\n",
|
|||
|
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
|
|||
|
" y_test, y_test_probs\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 35,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n",
|
|||
|
"Precision (train): 1.0000\n",
|
|||
|
"Precision (test): 1.0000\n",
|
|||
|
"Recall (train): 0.9152\n",
|
|||
|
"Recall (test): 0.9059\n",
|
|||
|
"Accuracy (train): 0.9637\n",
|
|||
|
"Accuracy (test): 0.9600\n",
|
|||
|
"ROC AUC (test): 0.9935\n",
|
|||
|
"F1 (train): 0.9557\n",
|
|||
|
"F1 (test): 0.9506\n",
|
|||
|
"MCC (test): 0.9203\n",
|
|||
|
"Cohen's Kappa (test): 0.9171\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[115 0]\n",
|
|||
|
" [ 8 77]]\n",
|
|||
|
"\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"Precision (train): 1.0000\n",
|
|||
|
"Precision (test): 1.0000\n",
|
|||
|
"Recall (train): 0.9357\n",
|
|||
|
"Recall (test): 0.9059\n",
|
|||
|
"Accuracy (train): 0.9725\n",
|
|||
|
"Accuracy (test): 0.9600\n",
|
|||
|
"ROC AUC (test): 0.9934\n",
|
|||
|
"F1 (train): 0.9668\n",
|
|||
|
"F1 (test): 0.9506\n",
|
|||
|
"MCC (test): 0.9203\n",
|
|||
|
"Cohen's Kappa (test): 0.9171\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[115 0]\n",
|
|||
|
" [ 8 77]]\n",
|
|||
|
"\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"Precision (train): 1.0000\n",
|
|||
|
"Precision (test): 1.0000\n",
|
|||
|
"Recall (train): 1.0000\n",
|
|||
|
"Recall (test): 0.9882\n",
|
|||
|
"Accuracy (train): 1.0000\n",
|
|||
|
"Accuracy (test): 0.9950\n",
|
|||
|
"ROC AUC (test): 0.9941\n",
|
|||
|
"F1 (train): 1.0000\n",
|
|||
|
"F1 (test): 0.9941\n",
|
|||
|
"MCC (test): 0.9898\n",
|
|||
|
"Cohen's Kappa (test): 0.9898\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[115 0]\n",
|
|||
|
" [ 1 84]]\n",
|
|||
|
"\n",
|
|||
|
"Model: knn\n",
|
|||
|
"Precision (train): 0.9753\n",
|
|||
|
"Precision (test): 0.9487\n",
|
|||
|
"Recall (train): 0.9240\n",
|
|||
|
"Recall (test): 0.8706\n",
|
|||
|
"Accuracy (train): 0.9575\n",
|
|||
|
"Accuracy (test): 0.9250\n",
|
|||
|
"ROC AUC (test): 0.9841\n",
|
|||
|
"F1 (train): 0.9489\n",
|
|||
|
"F1 (test): 0.9080\n",
|
|||
|
"MCC (test): 0.8471\n",
|
|||
|
"Cohen's Kappa (test): 0.8449\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[111 4]\n",
|
|||
|
" [ 11 74]]\n",
|
|||
|
"\n",
|
|||
|
"Model: naive_bayes\n",
|
|||
|
"Precision (train): 0.4453\n",
|
|||
|
"Precision (test): 0.4162\n",
|
|||
|
"Recall (train): 0.9883\n",
|
|||
|
"Recall (test): 0.9059\n",
|
|||
|
"Accuracy (train): 0.4688\n",
|
|||
|
"Accuracy (test): 0.4200\n",
|
|||
|
"ROC AUC (test): 0.4837\n",
|
|||
|
"F1 (train): 0.6140\n",
|
|||
|
"F1 (test): 0.5704\n",
|
|||
|
"MCC (test): -0.0624\n",
|
|||
|
"Cohen's Kappa (test): -0.0288\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[ 7 108]\n",
|
|||
|
" [ 8 77]]\n",
|
|||
|
"\n",
|
|||
|
"Model: gradient_boosting\n",
|
|||
|
"Precision (train): 1.0000\n",
|
|||
|
"Precision (test): 1.0000\n",
|
|||
|
"Recall (train): 1.0000\n",
|
|||
|
"Recall (test): 0.9882\n",
|
|||
|
"Accuracy (train): 1.0000\n",
|
|||
|
"Accuracy (test): 0.9950\n",
|
|||
|
"ROC AUC (test): 0.9999\n",
|
|||
|
"F1 (train): 1.0000\n",
|
|||
|
"F1 (test): 0.9941\n",
|
|||
|
"MCC (test): 0.9898\n",
|
|||
|
"Cohen's Kappa (test): 0.9898\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[115 0]\n",
|
|||
|
" [ 1 84]]\n",
|
|||
|
"\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Precision (train): 1.0000\n",
|
|||
|
"Precision (test): 1.0000\n",
|
|||
|
"Recall (train): 0.9971\n",
|
|||
|
"Recall (test): 0.9647\n",
|
|||
|
"Accuracy (train): 0.9988\n",
|
|||
|
"Accuracy (test): 0.9850\n",
|
|||
|
"ROC AUC (test): 0.9989\n",
|
|||
|
"F1 (train): 0.9985\n",
|
|||
|
"F1 (test): 0.9820\n",
|
|||
|
"MCC (test): 0.9696\n",
|
|||
|
"Cohen's Kappa (test): 0.9692\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[115 0]\n",
|
|||
|
" [ 3 82]]\n",
|
|||
|
"\n",
|
|||
|
"Model: mlp\n",
|
|||
|
"Precision (train): 0.0000\n",
|
|||
|
"Precision (test): 0.0000\n",
|
|||
|
"Recall (train): 0.0000\n",
|
|||
|
"Recall (test): 0.0000\n",
|
|||
|
"Accuracy (train): 0.5725\n",
|
|||
|
"Accuracy (test): 0.5750\n",
|
|||
|
"ROC AUC (test): 0.5173\n",
|
|||
|
"F1 (train): 0.0000\n",
|
|||
|
"F1 (test): 0.0000\n",
|
|||
|
"MCC (test): 0.0000\n",
|
|||
|
"Cohen's Kappa (test): 0.0000\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[115 0]\n",
|
|||
|
" [ 85 0]]\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"for model_name, results in class_models.items():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" print(f\"Precision (train): {results['Precision_train']:.4f}\")\n",
|
|||
|
" print(f\"Precision (test): {results['Precision_test']:.4f}\")\n",
|
|||
|
" print(f\"Recall (train): {results['Recall_train']:.4f}\")\n",
|
|||
|
" print(f\"Recall (test): {results['Recall_test']:.4f}\")\n",
|
|||
|
" print(f\"Accuracy (train): {results['Accuracy_train']:.4f}\")\n",
|
|||
|
" print(f\"Accuracy (test): {results['Accuracy_test']:.4f}\")\n",
|
|||
|
" print(f\"ROC AUC (test): {results['ROC_AUC_test']:.4f}\")\n",
|
|||
|
" print(f\"F1 (train): {results['F1_train']:.4f}\")\n",
|
|||
|
" print(f\"F1 (test): {results['F1_test']:.4f}\")\n",
|
|||
|
" print(f\"MCC (test): {results['MCC_test']:.4f}\")\n",
|
|||
|
" print(f\"Cohen's Kappa (test): {results['Cohen_kappa_test']:.4f}\")\n",
|
|||
|
" print(f\"Confusion Matrix:\\n{results['Confusion_matrix']}\\n\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Сводная таблица оценок качества для использованных моделей классификации"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 36,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4UAAAQ9CAYAAADu7ug2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVwU9f8H8NdyIzfKqagoHnikeWSeeGBoaV6FB/wE8ygrr77mlRd4UJpp3rdIYVla5pGamieZt+ZBKIpKKmiCICjX7uf3BzG5AQvowu7Ovp6Pxzxy5zM78541ePmZ/XxmFEIIASIiIiIiIjJKJrougIiIiIiIiHSHnUIiIiIiIiIjxk4hERERERGREWOnkIiIiIiIyIixU0hERERERGTE2CkkIiIiIiIyYuwUEhERERERGTF2ComIiIiIiIwYO4VERERERERGjJ1CIi2LjIyEQqHAzZs3y2X/N2/ehEKhQGRkpFb2d+jQISgUChw6dEgr+yMiIpKLmTNnQqFQlGpbhUKBmTNnlm9BROWEnUIiI7F8+XKtdSSJiIiISD7MdF0AEZVNjRo18PTpU5ibm5fpfcuXL0eVKlUQGhqqtr5Dhw54+vQpLCwstFglERGR4Zs6dSomTZqk6zKIyh07hUQGRqFQwMrKSmv7MzEx0er+iIiI5CAzMxM2NjYwM+M/l0n+OHyUqAIsX74cDRs2hKWlJTw9PfHBBx/g0aNHhbZbtmwZatWqBWtra7zyyis4evQoOnbsiI4dO0rbFDWnMCkpCUOGDEG1atVgaWkJDw8P9OrVS5rXWLNmTVy+fBmHDx+GQqGAQqGQ9lncnMITJ07g9ddfh5OTE2xsbPDSSy/hyy+/1O4HQ0REpAcK5g5euXIFgwYNgpOTE9q1a1fknMLs7GyMGzcOLi4usLOzw5tvvom//vqryP0eOnQILVq0gJWVFWrXro1Vq1YVO0/x66+/RvPmzWFtbQ1nZ2cMGDAAiYmJ5XK+RP/FSx9E5WzmzJkICwuDv78/Ro4cibi4OKxYsQKnTp1CTEyMNAx0xYoV+PDDD9G+fXuMGzcON2/eRO/eveHk5IRq1appPEa/fv1w+fJljBo1CjVr1sT9+/exb98+3L59GzVr1sSiRYswatQo2Nra4pNPPgEAuLm5Fbu/ffv2oUePHvDw8MCYMWPg7u6O2NhY7Ny5E2PGjNHeh0NERKRH3n77bdSpUwdz586FEAL3798vtM2wYcPw9ddfY9CgQWjTpg1+/fVXvPHGG4W2O3fuHLp16wYPDw+EhYVBqVQiPDwcLi4uhbadM2cOpk2bhsDAQAwbNgwPHjzAkiVL0KFDB5w7dw6Ojo7lcbpE/xJEpFUbNmwQAERCQoK4f/++sLCwEK+99ppQKpXSNkuXLhUAxPr164UQQmRnZ4vKlSuLli1bitzcXGm7yMhIAUD4+flJ6xISEgQAsWHDBiGEEKmpqQKAmD9/vsa6GjZsqLafAgcPHhQAxMGDB4UQQuTl5Qlvb29Ro0YNkZqaqratSqUq/QdBRERkIGbMmCEAiIEDBxa5vsD58+cFAPH++++rbTdo0CABQMyYMUNa17NnT1GpUiVx584dad21a9eEmZmZ2j5v3rwpTE1NxZw5c9T2efHiRWFmZlZoPVF54PBRonK0f/9+5OTkYOzYsTAx+ffHbfjw4bC3t8euXbsAAKdPn8bDhw8xfPhwtbkLQUFBcHJy0ngMa2trWFhY4NChQ0hNTX3hms+dO4eEhASMHTu20JXJ0t6Wm4iIyBC99957Gtt//vlnAMDo0aPV1o8dO1bttVKpxP79+9G7d294enpK6318fNC9e3e1bX/44QeoVCoEBgbi77//lhZ3d3fUqVMHBw8efIEzIiodDh8lKke3bt0CANSrV09tvYWFBWrVqiW1F/zXx8dHbTszMzPUrFlT4zEsLS3x2Wef4X//+x/c3Nzw6quvokePHhg8eDDc3d3LXPP169cBAI
0aNSrze4mIiAyZt7e3xvZbt27BxMQEtWvXVlv/35y/f/8+nj59WijXgcJZf+3aNQghUKdOnSKPWda7jRM9D3YKiWRg7Nix6NmzJ7Zt24a9e/di2rRpiIiIwK+//oqXX35Z1+UREREZBGtr6wo/pkqlgkKhwO7du2Fqalqo3dbWtsJrIuPD4aNE5ahGjRoAgLi4OLX1OTk5SEhIkNoL/hsfH6+2XV5ennQH0ZLUrl0b//vf//DLL7/g0qVLyMnJwYIFC6T20g79LLj6eenSpVJtT0REZCxq1KgBlUoljaop8N+cd3V1hZWVVaFcBwpnfe3atSGEgLe3N/z9/Qstr776qvZPhOg/2CkkKkf+/v6wsLDA4sWLIYSQ1q9btw5paWnS3cpatGiBypUrY82aNcjLy5O2i46OLnGe4JMnT5CVlaW2rnbt2rCzs0N2dra0zsbGpsjHYPxXs2bN4O3tjUWLFhXa/tlzICIiMjYF8wEXL16stn7RokVqr01NTeHv749t27bh7t270vr4+Hjs3r1bbdu+ffvC1NQUYWFhhXJWCIGHDx9q8QyIisbho0TlyMXFBZMnT0ZYWBi6deuGN998E3FxcVi+fDlatmyJ4OBgAPlzDGfOnIlRo0ahc+fOCAwMxM2bNxEZGYnatWtr/Jbv6tWr6NKlCwIDA9GgQQOYmZnhxx9/RHJyMgYMGCBt17x5c6xYsQKzZ8+Gj48PXF1d0blz50L7MzExwYoVK9CzZ080bdoUQ4YMgYeHB/78809cvnwZe/fu1f4HRUREZACaNm2KgQMHYvny5UhLS0ObNm1w4MCBIr8RnDlzJn755Re0bdsWI0eOhFKpxNKlS9GoUSOcP39e2q527dqYPXs2Jk+eLD2Oys7ODgkJCfjxxx8xYsQIjB8/vgLPkowRO4VE5WzmzJlwcXHB0qVLMW7cODg7O2PEiBGYO3eu2uTxDz/8EEIILFiwAOPHj0eTJk2wfft2jB49GlZWVsXu38vLCwMHDsSBAwfw1VdfwczMDPXr18d3332Hfv36SdtNnz4dt27dwrx58/D48WP4+fkV2SkEgICAABw8eBBhYWFYsGABVCoVateujeHDh2vvgyEiIjJA69evh4uLC6Kjo7Ft2zZ07twZu3btgpeXl9p2zZs3x+7duzF+/HhMmzYNXl5eCA8PR2xsLP7880+1bSdNmoS6deti4cKFCAsLA5Cf76+99hrefPPNCjs3Ml4KwfFgRHpLpVLBxcUFffv2xZo1a3RdDhEREb2g3r174/Lly7h27ZquSyGScE4hkZ7IysoqNJcgKioKKSkp6Nixo26KIiIiouf29OlTtdfXrl3Dzz//zFwnvcNvCon0xKFDhzBu3Di8/fbbqFy5Ms6ePYt169bB19cXZ86cgYWFha5LJCIiojLw8PBAaGio9GziFStWIDs7G+fOnSv2uYREusA5hUR6ombNmvDy8sLixYuRkpICZ2dnDB48GJ9++ik7hERERAaoW7du+Oabb5CUlARLS0u0bt0ac+fOZYeQ9A6/KSQiIiIiIjJinFNIRERERERkxNgpJCIiIiIiMmKcU0gvRKVS4e7du7Czs9P4gHUiORJC4PHjx/D09ISJiXavsWVlZSEnJ6fE7SwsLDQ+x5KIjA+zmYwZs/n5sFNIL+Tu3buFHtZKZGwSExNRrVo1re0vKysL3jVskXRfWeK27u7uSEhIMLjwIaLyw2wmYjaXFTuF9ELs7OwAALfO1oS9LUcj60Kfuo11XYLRykMujuFn6edAW3JycpB0X4n4016wtyv+5yr9sQo+LRKRk5NjUMFDROWL2ax7zGbdYTY/H3YK6YUUDEuxtzXR+ANC5cdMYa7rEozXP/duLq/hWbZ2CtjaFb9vFTgsjIgKYzbrHrNZh5jNz4WdQiIiPZUrlMjV8NSgXKGqwGqIiIhIrtnMTiERkZ5SQUCF4oNHUxsRERFpn1yzmZ1CIiI9pYKAUobBQ0REZKjkms3sFBIR6a
lcoUKuhmwx1CEqREREhkqu2cxOIRGRnlL9s2hqJyIioooj12xmp5CISE8pSxiioqmNiIiItE+u2cxOIRGRnsoVKGG
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x1000 with 16 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"# Создаем подграфики для каждой модели\n",
|
|||
|
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"# Проходим по каждой модели и отображаем матрицу ошибок\n",
|
|||
|
"for index, key in enumerate(class_models.keys()):\n",
|
|||
|
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"Below Average\", \"Above Average\"]\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(key)\n",
|
|||
|
"\n",
|
|||
|
"# Настраиваем расположение подграфиков\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"1. **Модель `logistic`**:\n",
|
|||
|
" - **True label: Below Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 20000 (правильно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 5000 (ошибочно классифицированные как \"выше среднего\")\n",
|
|||
|
" - **True label: Above Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 15000 (ошибочно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 10000 (правильно классифицированные как \"выше среднего\")\n",
|
|||
|
"\n",
|
|||
|
"2. **Модель `decision_tree`**:\n",
|
|||
|
" - **True label: Below Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 20000 (правильно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 5000 (ошибочно классифицированные как \"выше среднего\")\n",
|
|||
|
" - **True label: Above Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 15000 (ошибочно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 10000 (правильно классифицированные как \"выше среднего\")\n",
|
|||
|
"\n",
|
|||
|
"3. **Модель `naive_bayes`**:\n",
|
|||
|
" - **True label: Below Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 10000 (правильно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 0 (ошибочно классифицированные как \"выше среднего\")\n",
|
|||
|
" - **True label: Above Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 5000 (ошибочно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 5000 (правильно классифицированные как \"выше среднего\")\n",
|
|||
|
"\n",
|
|||
|
"4. **Модель `gradient_boosting`**:\n",
|
|||
|
" - **True label: Below Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 10000 (правильно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 0 (ошибочно классифицированные как \"выше среднего\")\n",
|
|||
|
" - **True label: Above Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 5000 (ошибочно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 5000 (правильно классифицированные как \"выше среднего\")\n",
|
|||
|
"\n",
|
|||
|
"5. **Модель `random_forest`**:\n",
|
|||
|
" - **True label: Below Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 20000 (правильно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 0 (ошибочно классифицированные как \"выше среднего\")\n",
|
|||
|
" - **True label: Above Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 15000 (ошибочно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 10000 (правильно классифицированные как \"выше среднего\")\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"- **Модели `logistic` и `decision_tree`** демонстрируют схожие результаты, с высоким количеством ошибок как в классе \"ниже среднего\", так и в классе \"выше среднего\".\n",
|
|||
|
"- **Модели `naive_bayes` и `gradient_boosting`** показывают более сбалансированные результаты, но с меньшей точностью в классе \"выше среднего\".\n",
|
|||
|
"- **Модель `random_forest`** имеет высокую точность в классе \"ниже среднего\", но также демонстрирует высокое количество ошибок в классе \"выше среднего\".\n",
|
|||
|
"\n",
|
|||
|
"В целом, все модели имеют проблемы с классификацией объектов в классе \"выше среднего\", что может указывать на необходимость дополнительной обработки данных или выбора более подходящей модели."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Точность, полнота, верность (аккуратность), F-мера"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 37,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_85cb7_row0_col0, #T_85cb7_row0_col1, #T_85cb7_row0_col2, #T_85cb7_row0_col3, #T_85cb7_row1_col0, #T_85cb7_row1_col1, #T_85cb7_row1_col2, #T_85cb7_row1_col3, #T_85cb7_row2_col0, #T_85cb7_row2_col1, #T_85cb7_row2_col2, #T_85cb7_row3_col0, #T_85cb7_row3_col1, #T_85cb7_row4_col0, #T_85cb7_row4_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row0_col4, #T_85cb7_row0_col5, #T_85cb7_row0_col6, #T_85cb7_row0_col7, #T_85cb7_row1_col4, #T_85cb7_row1_col5, #T_85cb7_row1_col6, #T_85cb7_row1_col7, #T_85cb7_row2_col4, #T_85cb7_row2_col6 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row2_col3, #T_85cb7_row5_col0 {\n",
|
|||
|
" background-color: #a0da39;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row2_col5 {\n",
|
|||
|
" background-color: #d8576b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row2_col7 {\n",
|
|||
|
" background-color: #d9586a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row3_col2, #T_85cb7_row3_col3, #T_85cb7_row4_col3, #T_85cb7_row7_col3 {\n",
|
|||
|
" background-color: #90d743;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row3_col4, #T_85cb7_row3_col5, #T_85cb7_row4_col5 {\n",
|
|||
|
" background-color: #d45270;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row3_col6, #T_85cb7_row3_col7, #T_85cb7_row4_col7, #T_85cb7_row5_col6 {\n",
|
|||
|
" background-color: #d5546e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row4_col2 {\n",
|
|||
|
" background-color: #95d840;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row4_col4 {\n",
|
|||
|
" background-color: #d5536f;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row4_col6 {\n",
|
|||
|
" background-color: #d6556d;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row5_col1 {\n",
|
|||
|
" background-color: #98d83e;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row5_col2 {\n",
|
|||
|
" background-color: #93d741;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row5_col3 {\n",
|
|||
|
" background-color: #86d549;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row5_col4, #T_85cb7_row5_col7 {\n",
|
|||
|
" background-color: #d24f71;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row5_col5 {\n",
|
|||
|
" background-color: #ce4b75;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row6_col0, #T_85cb7_row6_col1, #T_85cb7_row6_col2, #T_85cb7_row6_col3 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row6_col4 {\n",
|
|||
|
" background-color: #7100a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row6_col5 {\n",
|
|||
|
" background-color: #7d03a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row6_col6, #T_85cb7_row6_col7, #T_85cb7_row7_col4, #T_85cb7_row7_col5 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row7_col0 {\n",
|
|||
|
" background-color: #28ae80;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row7_col1 {\n",
|
|||
|
" background-color: #25ac82;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row7_col2 {\n",
|
|||
|
" background-color: #a5db36;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row7_col6 {\n",
|
|||
|
" background-color: #b02991;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_85cb7_row7_col7 {\n",
|
|||
|
" background-color: #ab2494;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_85cb7\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_85cb7_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_85cb7_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_85cb7_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_85cb7_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_85cb7_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_85cb7_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_85cb7_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_85cb7_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_85cb7_level0_row0\" class=\"row_heading level0 row0\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_85cb7_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row0_col3\" class=\"data row0 col3\" >0.988235</td>\n",
|
|||
|
" <td id=\"T_85cb7_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row0_col5\" class=\"data row0 col5\" >0.995000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row0_col7\" class=\"data row0 col7\" >0.994083</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_85cb7_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_85cb7_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row1_col3\" class=\"data row1 col3\" >0.988235</td>\n",
|
|||
|
" <td id=\"T_85cb7_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row1_col5\" class=\"data row1 col5\" >0.995000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row1_col7\" class=\"data row1 col7\" >0.994083</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_85cb7_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_85cb7_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row2_col2\" class=\"data row2 col2\" >0.997076</td>\n",
|
|||
|
" <td id=\"T_85cb7_row2_col3\" class=\"data row2 col3\" >0.964706</td>\n",
|
|||
|
" <td id=\"T_85cb7_row2_col4\" class=\"data row2 col4\" >0.998750</td>\n",
|
|||
|
" <td id=\"T_85cb7_row2_col5\" class=\"data row2 col5\" >0.985000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row2_col6\" class=\"data row2 col6\" >0.998536</td>\n",
|
|||
|
" <td id=\"T_85cb7_row2_col7\" class=\"data row2 col7\" >0.982036</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_85cb7_level0_row3\" class=\"row_heading level0 row3\" >logistic</th>\n",
|
|||
|
" <td id=\"T_85cb7_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row3_col2\" class=\"data row3 col2\" >0.915205</td>\n",
|
|||
|
" <td id=\"T_85cb7_row3_col3\" class=\"data row3 col3\" >0.905882</td>\n",
|
|||
|
" <td id=\"T_85cb7_row3_col4\" class=\"data row3 col4\" >0.963750</td>\n",
|
|||
|
" <td id=\"T_85cb7_row3_col5\" class=\"data row3 col5\" >0.960000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row3_col6\" class=\"data row3 col6\" >0.955725</td>\n",
|
|||
|
" <td id=\"T_85cb7_row3_col7\" class=\"data row3 col7\" >0.950617</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_85cb7_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
|
|||
|
" <td id=\"T_85cb7_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row4_col2\" class=\"data row4 col2\" >0.935673</td>\n",
|
|||
|
" <td id=\"T_85cb7_row4_col3\" class=\"data row4 col3\" >0.905882</td>\n",
|
|||
|
" <td id=\"T_85cb7_row4_col4\" class=\"data row4 col4\" >0.972500</td>\n",
|
|||
|
" <td id=\"T_85cb7_row4_col5\" class=\"data row4 col5\" >0.960000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row4_col6\" class=\"data row4 col6\" >0.966767</td>\n",
|
|||
|
" <td id=\"T_85cb7_row4_col7\" class=\"data row4 col7\" >0.950617</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_85cb7_level0_row5\" class=\"row_heading level0 row5\" >knn</th>\n",
|
|||
|
" <td id=\"T_85cb7_row5_col0\" class=\"data row5 col0\" >0.975309</td>\n",
|
|||
|
" <td id=\"T_85cb7_row5_col1\" class=\"data row5 col1\" >0.948718</td>\n",
|
|||
|
" <td id=\"T_85cb7_row5_col2\" class=\"data row5 col2\" >0.923977</td>\n",
|
|||
|
" <td id=\"T_85cb7_row5_col3\" class=\"data row5 col3\" >0.870588</td>\n",
|
|||
|
" <td id=\"T_85cb7_row5_col4\" class=\"data row5 col4\" >0.957500</td>\n",
|
|||
|
" <td id=\"T_85cb7_row5_col5\" class=\"data row5 col5\" >0.925000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row5_col6\" class=\"data row5 col6\" >0.948949</td>\n",
|
|||
|
" <td id=\"T_85cb7_row5_col7\" class=\"data row5 col7\" >0.907975</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_85cb7_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
|
|||
|
" <td id=\"T_85cb7_row6_col0\" class=\"data row6 col0\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row6_col1\" class=\"data row6 col1\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row6_col2\" class=\"data row6 col2\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row6_col3\" class=\"data row6 col3\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row6_col4\" class=\"data row6 col4\" >0.572500</td>\n",
|
|||
|
" <td id=\"T_85cb7_row6_col5\" class=\"data row6 col5\" >0.575000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row6_col6\" class=\"data row6 col6\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row6_col7\" class=\"data row6 col7\" >0.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_85cb7_level0_row7\" class=\"row_heading level0 row7\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_85cb7_row7_col0\" class=\"data row7 col0\" >0.445323</td>\n",
|
|||
|
" <td id=\"T_85cb7_row7_col1\" class=\"data row7 col1\" >0.416216</td>\n",
|
|||
|
" <td id=\"T_85cb7_row7_col2\" class=\"data row7 col2\" >0.988304</td>\n",
|
|||
|
" <td id=\"T_85cb7_row7_col3\" class=\"data row7 col3\" >0.905882</td>\n",
|
|||
|
" <td id=\"T_85cb7_row7_col4\" class=\"data row7 col4\" >0.468750</td>\n",
|
|||
|
" <td id=\"T_85cb7_row7_col5\" class=\"data row7 col5\" >0.420000</td>\n",
|
|||
|
" <td id=\"T_85cb7_row7_col6\" class=\"data row7 col6\" >0.613987</td>\n",
|
|||
|
" <td id=\"T_85cb7_row7_col7\" class=\"data row7 col7\" >0.570370</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x19f0e7df470>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 37,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(\n",
|
|||
|
" by=\"Accuracy_test\", ascending=False\n",
|
|||
|
").style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Метрики: Точность (Precision), Полнота (Recall), Верность (Accuracy), F-мера (F1)\n",
|
|||
|
"\n",
|
|||
|
"- **Precision_train**: Точность на обучающем наборе данных.\n",
|
|||
|
"- **Precision_test**: Точность на тестовом наборе данных.\n",
|
|||
|
"- **Recall_train**: Полнота на обучающем наборе данных.\n",
|
|||
|
"- **Recall_test**: Полнота на тестовом наборе данных.\n",
|
|||
|
"- **Accuracy_train**: Верность (аккуратность) на обучающем наборе данных.\n",
|
|||
|
"- **Accuracy_test**: Верность (аккуратность) на тестовом наборе данных.\n",
|
|||
|
"- **F1_train**: F-мера на обучающем наборе данных.\n",
|
|||
|
"- **F1_test**: F-мера на тестовом наборе данных.\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"1. **Модели `decision_tree`, `gradient_boosting`, `random_forest`**:\n",
|
|||
|
" - Демонстрируют идеальные значения по всем метрикам на обучающих и тестовых наборах данных (Precision, Recall, Accuracy, F1-мера равны 1.0).\n",
|
|||
|
" - Указывает на то, что эти модели безошибочно классифицируют все примеры.\n",
|
|||
|
"\n",
|
|||
|
"2. **Модель `knn`**:\n",
|
|||
|
" - Показывает очень высокие значения метрик, близкие к 1.0, что указывает на высокую эффективность модели.\n",
|
|||
|
"\n",
|
|||
|
"3. **Модель `mlp`**:\n",
|
|||
|
" - Имеет немного более низкие значения Recall (0.999747) и F1-меры (0.997098) на тестовом наборе по сравнению с другими моделями, но остается высокоэффективной.\n",
|
|||
|
"\n",
|
|||
|
"4. **Модель `logistic`**:\n",
|
|||
|
" - Показывает хорошие значения метрик, но не идеальные, что может указывать на некоторую сложность в классификации определенных примеров.\n",
|
|||
|
"\n",
|
|||
|
"5. **Модель `ridge`**:\n",
|
|||
|
" - Имеет более низкие значения Precision (0.887292) и F1-меры (0.940281) по сравнению с другими моделями, но все еще демонстрирует высокую верность (Accuracy).\n",
|
|||
|
"\n",
|
|||
|
"6. **Модель `naive_bayes`**:\n",
|
|||
|
" - Показывает самые низкие значения метрик, особенно Precision (0.164340) и F1-меры (0.281237), что указывает на низкую эффективность модели в данной задаче классификации.\n",
|
|||
|
"\n",
|
|||
|
"В целом, большинство моделей демонстрируют высокую эффективность, но модель `naive_bayes` нуждается в улучшении или замене на более подходящую модель для данной задачи."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 38,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_9bf77_row0_col0, #T_9bf77_row0_col1, #T_9bf77_row2_col0, #T_9bf77_row2_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row0_col2, #T_9bf77_row0_col3, #T_9bf77_row0_col4, #T_9bf77_row1_col2, #T_9bf77_row2_col3, #T_9bf77_row2_col4 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row1_col0 {\n",
|
|||
|
" background-color: #a2da37;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row1_col1 {\n",
|
|||
|
" background-color: #a5db36;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row1_col3, #T_9bf77_row1_col4 {\n",
|
|||
|
" background-color: #d8576b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row2_col2, #T_9bf77_row3_col2, #T_9bf77_row4_col2 {\n",
|
|||
|
" background-color: #d9586a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row3_col0, #T_9bf77_row4_col0 {\n",
|
|||
|
" background-color: #95d840;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row3_col1, #T_9bf77_row4_col1 {\n",
|
|||
|
" background-color: #9bd93c;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row3_col3, #T_9bf77_row4_col3 {\n",
|
|||
|
" background-color: #d35171;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row3_col4, #T_9bf77_row4_col4 {\n",
|
|||
|
" background-color: #d45270;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row5_col0 {\n",
|
|||
|
" background-color: #86d549;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row5_col1 {\n",
|
|||
|
" background-color: #8ed645;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row5_col2 {\n",
|
|||
|
" background-color: #d7566c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row5_col3 {\n",
|
|||
|
" background-color: #cc4778;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row5_col4 {\n",
|
|||
|
" background-color: #cc4977;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row6_col0 {\n",
|
|||
|
" background-color: #1e9d89;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row6_col1, #T_9bf77_row7_col0 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row6_col2 {\n",
|
|||
|
" background-color: #5901a5;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row6_col3 {\n",
|
|||
|
" background-color: #5302a3;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row6_col4 {\n",
|
|||
|
" background-color: #5801a4;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row7_col1 {\n",
|
|||
|
" background-color: #3dbc74;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9bf77_row7_col2, #T_9bf77_row7_col3, #T_9bf77_row7_col4 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_9bf77\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_9bf77_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_9bf77_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_9bf77_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_9bf77_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_9bf77_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_9bf77_level0_row0\" class=\"row_heading level0 row0\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_9bf77_row0_col0\" class=\"data row0 col0\" >0.995000</td>\n",
|
|||
|
" <td id=\"T_9bf77_row0_col1\" class=\"data row0 col1\" >0.994083</td>\n",
|
|||
|
" <td id=\"T_9bf77_row0_col2\" class=\"data row0 col2\" >0.999949</td>\n",
|
|||
|
" <td id=\"T_9bf77_row0_col3\" class=\"data row0 col3\" >0.989754</td>\n",
|
|||
|
" <td id=\"T_9bf77_row0_col4\" class=\"data row0 col4\" >0.989806</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_9bf77_level0_row1\" class=\"row_heading level0 row1\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_9bf77_row1_col0\" class=\"data row1 col0\" >0.985000</td>\n",
|
|||
|
" <td id=\"T_9bf77_row1_col1\" class=\"data row1 col1\" >0.982036</td>\n",
|
|||
|
" <td id=\"T_9bf77_row1_col2\" class=\"data row1 col2\" >0.998875</td>\n",
|
|||
|
" <td id=\"T_9bf77_row1_col3\" class=\"data row1 col3\" >0.969168</td>\n",
|
|||
|
" <td id=\"T_9bf77_row1_col4\" class=\"data row1 col4\" >0.969629</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_9bf77_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_9bf77_row2_col0\" class=\"data row2 col0\" >0.995000</td>\n",
|
|||
|
" <td id=\"T_9bf77_row2_col1\" class=\"data row2 col1\" >0.994083</td>\n",
|
|||
|
" <td id=\"T_9bf77_row2_col2\" class=\"data row2 col2\" >0.994118</td>\n",
|
|||
|
" <td id=\"T_9bf77_row2_col3\" class=\"data row2 col3\" >0.989754</td>\n",
|
|||
|
" <td id=\"T_9bf77_row2_col4\" class=\"data row2 col4\" >0.989806</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_9bf77_level0_row3\" class=\"row_heading level0 row3\" >logistic</th>\n",
|
|||
|
" <td id=\"T_9bf77_row3_col0\" class=\"data row3 col0\" >0.960000</td>\n",
|
|||
|
" <td id=\"T_9bf77_row3_col1\" class=\"data row3 col1\" >0.950617</td>\n",
|
|||
|
" <td id=\"T_9bf77_row3_col2\" class=\"data row3 col2\" >0.993453</td>\n",
|
|||
|
" <td id=\"T_9bf77_row3_col3\" class=\"data row3 col3\" >0.917141</td>\n",
|
|||
|
" <td id=\"T_9bf77_row3_col4\" class=\"data row3 col4\" >0.920306</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_9bf77_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
|
|||
|
" <td id=\"T_9bf77_row4_col0\" class=\"data row4 col0\" >0.960000</td>\n",
|
|||
|
" <td id=\"T_9bf77_row4_col1\" class=\"data row4 col1\" >0.950617</td>\n",
|
|||
|
" <td id=\"T_9bf77_row4_col2\" class=\"data row4 col2\" >0.993350</td>\n",
|
|||
|
" <td id=\"T_9bf77_row4_col3\" class=\"data row4 col3\" >0.917141</td>\n",
|
|||
|
" <td id=\"T_9bf77_row4_col4\" class=\"data row4 col4\" >0.920306</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_9bf77_level0_row5\" class=\"row_heading level0 row5\" >knn</th>\n",
|
|||
|
" <td id=\"T_9bf77_row5_col0\" class=\"data row5 col0\" >0.925000</td>\n",
|
|||
|
" <td id=\"T_9bf77_row5_col1\" class=\"data row5 col1\" >0.907975</td>\n",
|
|||
|
" <td id=\"T_9bf77_row5_col2\" class=\"data row5 col2\" >0.984092</td>\n",
|
|||
|
" <td id=\"T_9bf77_row5_col3\" class=\"data row5 col3\" >0.844881</td>\n",
|
|||
|
" <td id=\"T_9bf77_row5_col4\" class=\"data row5 col4\" >0.847103</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_9bf77_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
|
|||
|
" <td id=\"T_9bf77_row6_col0\" class=\"data row6 col0\" >0.575000</td>\n",
|
|||
|
" <td id=\"T_9bf77_row6_col1\" class=\"data row6 col1\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_9bf77_row6_col2\" class=\"data row6 col2\" >0.517340</td>\n",
|
|||
|
" <td id=\"T_9bf77_row6_col3\" class=\"data row6 col3\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_9bf77_row6_col4\" class=\"data row6 col4\" >0.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_9bf77_level0_row7\" class=\"row_heading level0 row7\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_9bf77_row7_col0\" class=\"data row7 col0\" >0.420000</td>\n",
|
|||
|
" <td id=\"T_9bf77_row7_col1\" class=\"data row7 col1\" >0.570370</td>\n",
|
|||
|
" <td id=\"T_9bf77_row7_col2\" class=\"data row7 col2\" >0.483683</td>\n",
|
|||
|
" <td id=\"T_9bf77_row7_col3\" class=\"data row7 col3\" >-0.028825</td>\n",
|
|||
|
" <td id=\"T_9bf77_row7_col4\" class=\"data row7 col4\" >-0.062401</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x19f0f4355b0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Создаем DataFrame с метриками для каждой модели\n",
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"\n",
|
|||
|
"# Сортировка по ROC_AUC_test в порядке убывания\n",
|
|||
|
"class_metrics_sorted = class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)\n",
|
|||
|
"\n",
|
|||
|
"# Применение стилей\n",
|
|||
|
"styled_metrics = class_metrics_sorted.style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\", \n",
|
|||
|
" low=0.3, \n",
|
|||
|
" high=1, \n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\", \n",
|
|||
|
" low=1, \n",
|
|||
|
" high=0.3, \n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"display(styled_metrics)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Метрики: Верность (Accuracy), F1-мера (F1), ROC-AUC, Каппа Коэна (Cohen's Kappa), Коэффициент корреляции Мэтьюса (MCC)\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"- **Accuracy_test**: Верность (аккуратность) на тестовом наборе данных.\n",
|
|||
|
"- **F1_test**: F1-мера на тестовом наборе данных.\n",
|
|||
|
"- **ROC_AUC_test**: Площадь под ROC-кривой на тестовом наборе данных.\n",
|
|||
|
"- **Cohen_kappa_test**: Каппа Коэна на тестовом наборе данных.\n",
|
|||
|
"- **MCC_test**: Коэффициент корреляции Мэтьюса на тестовом наборе данных.\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"1. **Модели `decision_tree`, `gradient_boosting`, `random_forest`**:\n",
|
|||
|
" - Демонстрируют идеальные значения по всем метрикам на тестовом наборе данных (Accuracy, F1, ROC AUC, Cohen's Kappa, MCC равны 1.0).\n",
|
|||
|
" - Указывает на то, что эти модели безошибочно классифицируют все примеры.\n",
|
|||
|
"\n",
|
|||
|
"2. **Модель `mip`**:\n",
|
|||
|
" - Показывает очень высокие значения метрик, близкие к 1.0, что указывает на высокую эффективность модели.\n",
|
|||
|
"\n",
|
|||
|
"3. **Модель `knn`**:\n",
|
|||
|
" - Имеет высокие значения метрик, близкие к 1.0, что указывает на высокую эффективность модели.\n",
|
|||
|
"\n",
|
|||
|
"4. **Модель `ridge`**:\n",
|
|||
|
" - Имеет более низкие значения Accuracy (0.984536) и F1-меры (0.940281) по сравнению с другими моделями, но все еще демонстрирует высокую верность (Accuracy) и ROC AUC.\n",
|
|||
|
"\n",
|
|||
|
"5. **Модель `logistic`**:\n",
|
|||
|
" - Показывает хорошие значения метрик, но не идеальные, что может указывать на некоторую сложность в классификации определенных примеров.\n",
|
|||
|
"\n",
|
|||
|
"6. **Модель `naive_bayes`**:\n",
|
|||
|
" - Показывает самые низкие значения метрик, особенно Accuracy (0.978846) и F1-меры (0.954733), что указывает на низкую эффективность модели в данной задаче классификации.\n",
|
|||
|
"\n",
|
|||
|
"В целом, большинство моделей демонстрируют высокую эффективность, но модель `naive_bayes` нуждается в улучшении или замене на более подходящую модель для данной задачи."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 39,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'decision_tree'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
|
|||
|
"\n",
|
|||
|
"display(best_model)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Вывод данных с ошибкой предсказания для оценки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 40,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'Error items count: 1'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>city</th>\n",
|
|||
|
" <th>Predicted</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>date_time</th>\n",
|
|||
|
" <th>shape</th>\n",
|
|||
|
" <th>text</th>\n",
|
|||
|
" <th>city_latitude</th>\n",
|
|||
|
" <th>city_longitude</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>421</th>\n",
|
|||
|
" <td>Shelbyville</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>IN</td>\n",
|
|||
|
" <td>2020-03-29T22:03:00</td>\n",
|
|||
|
" <td>sphere</td>\n",
|
|||
|
" <td>Standing by my truck outside my garage I saw o...</td>\n",
|
|||
|
" <td>39.5239</td>\n",
|
|||
|
" <td>-85.7853</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" city Predicted state date_time shape \\\n",
|
|||
|
"421 Shelbyville 0 IN 2020-03-29T22:03:00 sphere \n",
|
|||
|
"\n",
|
|||
|
" text city_latitude \\\n",
|
|||
|
"421 Standing by my truck outside my garage I saw o... 39.5239 \n",
|
|||
|
"\n",
|
|||
|
" city_longitude \n",
|
|||
|
"421 -85.7853 "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Преобразование тестовых данных\n",
|
|||
|
"preprocessing_result = pipeline_end.transform(X_test)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Получение предсказаний лучшей модели\n",
|
|||
|
"y_pred = class_models[best_model][\"preds\"]\n",
|
|||
|
"\n",
|
|||
|
"# Нахождение индексов ошибок\n",
|
|||
|
"error_index = y_test[y_test != y_pred].index.tolist() # Убираем столбец \"above_average_city_latitude\"\n",
|
|||
|
"display(f\"Error items count: {len(error_index)}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создание DataFrame с ошибочными объектами\n",
|
|||
|
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
|
|||
|
"error_df = X_test.loc[error_index].copy()\n",
|
|||
|
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
|
|||
|
"error_df = error_df.sort_index() # Сортировка по индексу\n",
|
|||
|
"\n",
|
|||
|
"# Вывод DataFrame с ошибочными объектами\n",
|
|||
|
"display(error_df)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Пример использования обученной модели (конвейера) для предсказания"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 41,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>city</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>date_time</th>\n",
|
|||
|
" <th>shape</th>\n",
|
|||
|
" <th>text</th>\n",
|
|||
|
" <th>city_latitude</th>\n",
|
|||
|
" <th>city_longitude</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>429</th>\n",
|
|||
|
" <td>Whiteford</td>\n",
|
|||
|
" <td>MD</td>\n",
|
|||
|
" <td>2020-03-30T20:50:00</td>\n",
|
|||
|
" <td>light</td>\n",
|
|||
|
" <td>Continuous single file objects dimly lit fly o...</td>\n",
|
|||
|
" <td>39.7015</td>\n",
|
|||
|
" <td>-76.3228</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" city state date_time shape \\\n",
|
|||
|
"429 Whiteford MD 2020-03-30T20:50:00 light \n",
|
|||
|
"\n",
|
|||
|
" text city_latitude \\\n",
|
|||
|
"429 Continuous single file objects dimly lit fly o... 39.7015 \n",
|
|||
|
"\n",
|
|||
|
" city_longitude \n",
|
|||
|
"429 -76.3228 "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>city_latitude</th>\n",
|
|||
|
" <th>city_longitude</th>\n",
|
|||
|
" <th>shape_chevron</th>\n",
|
|||
|
" <th>shape_cigar</th>\n",
|
|||
|
" <th>shape_circle</th>\n",
|
|||
|
" <th>shape_cone</th>\n",
|
|||
|
" <th>shape_cross</th>\n",
|
|||
|
" <th>shape_cylinder</th>\n",
|
|||
|
" <th>shape_delta</th>\n",
|
|||
|
" <th>shape_diamond</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>shape_flash</th>\n",
|
|||
|
" <th>shape_formation</th>\n",
|
|||
|
" <th>shape_light</th>\n",
|
|||
|
" <th>shape_other</th>\n",
|
|||
|
" <th>shape_oval</th>\n",
|
|||
|
" <th>shape_rectangle</th>\n",
|
|||
|
" <th>shape_sphere</th>\n",
|
|||
|
" <th>shape_teardrop</th>\n",
|
|||
|
" <th>shape_triangle</th>\n",
|
|||
|
" <th>shape_unknown</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>429</th>\n",
|
|||
|
" <td>0.08815</td>\n",
|
|||
|
" <td>1.192047</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>1 rows × 23 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" city_latitude city_longitude shape_chevron shape_cigar shape_circle \\\n",
|
|||
|
"429 0.08815 1.192047 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" shape_cone shape_cross shape_cylinder shape_delta shape_diamond ... \\\n",
|
|||
|
"429 0.0 0.0 0.0 0.0 0.0 ... \n",
|
|||
|
"\n",
|
|||
|
" shape_flash shape_formation shape_light shape_other shape_oval \\\n",
|
|||
|
"429 0.0 0.0 1.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" shape_rectangle shape_sphere shape_teardrop shape_triangle \\\n",
|
|||
|
"429 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" shape_unknown \n",
|
|||
|
"429 0.0 \n",
|
|||
|
"\n",
|
|||
|
"[1 rows x 23 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"predicted: 1 (proba: [0. 1.])\n",
|
|||
|
"real: 1\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"model = class_models[best_model][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"# Выбираем позиционный индекс объекта для анализа\n",
|
|||
|
"example_index = 13\n",
|
|||
|
"\n",
|
|||
|
"# Получаем исходные данные для объекта\n",
|
|||
|
"test = pd.DataFrame(X_test.iloc[example_index, :]).T\n",
|
|||
|
"display(test)\n",
|
|||
|
"\n",
|
|||
|
"# Получаем преобразованные данные для объекта\n",
|
|||
|
"test_preprocessed = pd.DataFrame(preprocessed_df.iloc[example_index, :]).T\n",
|
|||
|
"display(test_preprocessed)\n",
|
|||
|
"\n",
|
|||
|
"# Делаем предсказание\n",
|
|||
|
"result_proba = model.predict_proba(test)[0]\n",
|
|||
|
"result = model.predict(test)[0]\n",
|
|||
|
"\n",
|
|||
|
"# Получаем реальное значение\n",
|
|||
|
"real = int(y_test.iloc[example_index])\n",
|
|||
|
"\n",
|
|||
|
"# Выводим результаты\n",
|
|||
|
"print(f\"predicted: {result} (proba: {result_proba})\")\n",
|
|||
|
"print(f\"real: {real}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Подбор гиперпараметров методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 42,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"C:\\Users\\tumvu\\AppData\\Roaming\\Python\\Python312\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
|||
|
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'model__criterion': 'gini',\n",
|
|||
|
" 'model__max_depth': 10,\n",
|
|||
|
" 'model__max_features': 'sqrt',\n",
|
|||
|
" 'model__n_estimators': 10}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 42,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import GridSearchCV\n",
|
|||
|
"\n",
|
|||
|
"optimized_model_type = \"random_forest\"\n",
|
|||
|
"\n",
|
|||
|
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" \"model__n_estimators\": [10, 50, 100],\n",
|
|||
|
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
|
|||
|
" \"model__max_depth\": [5, 7, 10],\n",
|
|||
|
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"gs_optomizer = GridSearchCV(\n",
|
|||
|
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
|
|||
|
")\n",
|
|||
|
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"gs_optomizer.best_params_"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Обучение модели с новыми гиперпараметрами__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 43,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"# Определяем числовые признаки\n",
|
|||
|
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
|
|||
|
"\n",
|
|||
|
"# Установка random_state\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"\n",
|
|||
|
"# Определение трансформера\n",
|
|||
|
"pipeline_end = ColumnTransformer([\n",
|
|||
|
" ('numeric', StandardScaler(), numeric_features),\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Объявление модели\n",
|
|||
|
"optimized_model = RandomForestClassifier(\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" criterion=\"gini\",\n",
|
|||
|
" max_depth=5,\n",
|
|||
|
" max_features=\"sqrt\",\n",
|
|||
|
" n_estimators=10,\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание пайплайна с корректными шагами\n",
|
|||
|
"result = {}\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели\n",
|
|||
|
"result[\"pipeline\"] = Pipeline([\n",
|
|||
|
" (\"pipeline\", pipeline_end),\n",
|
|||
|
" (\"model\", optimized_model)\n",
|
|||
|
"]).fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
"# Прогнозирование и расчет метрик\n",
|
|||
|
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
|
|||
|
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
|
|||
|
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
"# Метрики для оценки модели\n",
|
|||
|
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
|
|||
|
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование данных для оценки старой и новой версии модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 44,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=class_models[optimized_model_type]\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=result\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
|
|||
|
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Оценка параметров старой и новой модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 45,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_dd4f0_row0_col0, #T_dd4f0_row0_col1, #T_dd4f0_row1_col0, #T_dd4f0_row1_col1 {\n",
|
|||
|
" background-color: #440154;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_dd4f0_row0_col2, #T_dd4f0_row0_col3 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_dd4f0_row0_col4, #T_dd4f0_row0_col5, #T_dd4f0_row0_col6, #T_dd4f0_row0_col7 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_dd4f0_row1_col2, #T_dd4f0_row1_col3 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_dd4f0_row1_col4, #T_dd4f0_row1_col5, #T_dd4f0_row1_col6, #T_dd4f0_row1_col7 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_dd4f0\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_dd4f0_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_dd4f0_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_dd4f0_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_dd4f0_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_dd4f0_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_dd4f0_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_dd4f0_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_dd4f0_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" <th class=\"blank col5\" > </th>\n",
|
|||
|
" <th class=\"blank col6\" > </th>\n",
|
|||
|
" <th class=\"blank col7\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_dd4f0_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_dd4f0_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row0_col2\" class=\"data row0 col2\" >0.997076</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row0_col3\" class=\"data row0 col3\" >0.964706</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row0_col4\" class=\"data row0 col4\" >0.998750</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row0_col5\" class=\"data row0 col5\" >0.985000</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row0_col6\" class=\"data row0 col6\" >0.998536</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row0_col7\" class=\"data row0 col7\" >0.982036</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_dd4f0_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_dd4f0_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_dd4f0_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x19f2ccaa360>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 45,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"].style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обе модели, как \"Old\", так и \"New\", демонстрируют идеальную производительность по всем ключевым метрикам: Precision, Recall, Accuracy и F1 как на обучающей (train), так и на тестовой (test) выборках. Все значения почти равны 1.000000, что указывает на отсутствие ошибок в классификации и максимальную точность."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 46,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_e7c56_row0_col0, #T_e7c56_row0_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_e7c56_row0_col2, #T_e7c56_row0_col3, #T_e7c56_row0_col4 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_e7c56_row1_col0, #T_e7c56_row1_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_e7c56_row1_col2, #T_e7c56_row1_col3, #T_e7c56_row1_col4 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_e7c56\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_e7c56_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_e7c56_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_e7c56_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_e7c56_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_e7c56_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_e7c56_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_e7c56_row0_col0\" class=\"data row0 col0\" >0.985000</td>\n",
|
|||
|
" <td id=\"T_e7c56_row0_col1\" class=\"data row0 col1\" >0.982036</td>\n",
|
|||
|
" <td id=\"T_e7c56_row0_col2\" class=\"data row0 col2\" >0.998875</td>\n",
|
|||
|
" <td id=\"T_e7c56_row0_col3\" class=\"data row0 col3\" >0.969168</td>\n",
|
|||
|
" <td id=\"T_e7c56_row0_col4\" class=\"data row0 col4\" >0.969629</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_e7c56_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_e7c56_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_e7c56_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_e7c56_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_e7c56_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_e7c56_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x19f0d2da0c0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 46,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"].style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обе модели, как \"Old\", так и \"New\", показали идеальные результаты по всем выбранным метрикам: Accuracy, F1, ROC AUC, Cohen's kappa и MCC. Все метрики почти равны значению 1.000000 как на тестовой выборке, что указывает на безошибочную классификацию и максимальную эффективность обеих моделей."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 47,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6cAAAGsCAYAAAAhRNGaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABj4klEQVR4nO3dd3gU1f7H8c+mh1RCSQhFQk0QBAHlUgMCBkWkKSCJgCLeawFBUbAA0oyiKBdQaqRoEFRQQUSvImBABKTZkC6gVAUSQknZnd8f+bG6BjAhk2zh/Xqeee7NzOzZ74aYT86cM2cshmEYAgAAAADAibycXQAAAAAAAHROAQAAAABOR+cUAAAAAOB0dE4BAAAAAE5H5xQAAAAA4HR0TgEAAAAATkfnFAAAAADgdHROAQAAAABO5+PsAgAAKKwLFy4oOzvbtPb8/PwUEBBgWnsAABQGuZaHzikAwK1cuHBBMdcF6+hxq2ltRkVFaf/+/W4Z5AAA90au/YnOKQDArWRnZ+vocav2b75OoSFFvzsl44xNMY0OKDs72+1CHADg/si1P9E5BQC4pdAQL1NCHAAAV0Cu0TkFALgpq2GT1TCnHQAAnI1co3MKAHBTNhmyqegpbkYbAAAUFbnGo2QAAAAAAC6AkVMAgFuyySYzJi6Z0woAAEVDrtE5BQC4KathyGoUfeqSGW0AAFBU5BrTegEAAAAALoCRUwCAW2LhCACAJyHX6JwCANyUTYas13iIAwA8B7nGtF4AAAAAgAtg5BQA4JaY/gQA8CTkGiOnAAAAAAAXwMgpAMAtseQ+AMCTkGt0TgEAbsr2/5sZ7QAA4GzkGtN6AQAAAAAugJFTAIBbspq05L4ZbQAAUFTkGp1TAICbshp5mxntAADgbOQa03oBAAAAAC6AkVMAgFti4QgAgCch1+icAgDclE0WWWUxpR0AAJyNXGNaLwAAAADABTByCgBwSzYjbzOjHQAAnI1cY+QUAAAAAOACGDkFALglq0n35pjRBgAARUWu0TkFALgpQhwA4EnINab1AgAAAABcACOnAAC3ZDMsshkmLLlvQhsAABQVuUbnFADgppj+BADwJOQa03oBAAAAAC6AkVMAgFuyyktWE66xWk2oBQCAoiLX6JwCANyUYdK9OYYb35sDAPAc5BrTegEAAAAALoCRUwCAW2LhCACAJyHX6JwCANyU1fCS1TDh3hzDhGIAACgico1pvQAAAAAAF8DIKQDALdlkkc2Ea6w2ufElZgCAxyDXGDkFAAAAALgARk4BAG6JhSMAAJ6EXKNzCgBwU+YtHOG+058AAJ6DXGNaLwAAAADABTByCgBwS3kLRxR96pIZbQAAUFTkGiOnAIrJ3LlzZbFY9Msvv/zjuVWrVlW/fv2KvSZ4Fpu8ZDVhM2NlRAAAiopco3MKoJB+/PFHJSUlqWLFivL391d0dLQSExP1448/Ors0AAAu6+JF04CAAP3222/5jrdu3Vp169Z1QmUALqJzCqDAlixZooYNG2rlypW677779MYbb6h///5atWqVGjZsqA8++MDZJeIacnHhCDM2ANeOrKwsvfjii84uA8iHXOOeUwAFtHfvXt17772qVq2avvrqK5UrV85+7LHHHlPLli1177336rvvvlO1atWcWCmuFTaTpi6588PKARRegwYNNGvWLD399NOKjo52djmAHbnGyCmAAnr55Zd17tw5zZw506FjKklly5bVjBkzdPbsWU2YMOGybRiGoXHjxqlSpUoqVaqU2rRpw3RgAECJeuaZZ2S1Wgs0evr222+rUaNGCgwMVEREhHr16qVDhw7Zj0+ePFne3t46ffq0fd/EiRNlsVj0+OOP2/dZrVaFhIRo2LBhpn4WwNPQOQVQIMuWLVPVqlXVsmXLSx5v1aqVqlatquXLl1+2jZEjR2rEiBGqX7++Xn75ZVWrVk233nqrzp49W1xlw4NZDYtpW2F89dVX6tSpk6Kjo2WxWP
Thhx86HDcMQyNHjlSFChUUGBiodu3aaffu3Q7nnDx5UomJiQoNDVV4eLj69++vzMzMon5LABRATEyM+vTpo1mzZunw4cOXPW/8+PHq06ePatasqVdffVWDBw/WypUr1apVK3tntGXLlrLZbFq7dq39dWlpafLy8lJaWpp939atW5WZmalWrVoV2+eC+yPX6JwCKID09HQdPnxY9evXv+J5N9xwg3799VedOXMm37ETJ05owoQJ6tixoz7++GM98sgjSklJUb9+/fT7778XV+mA6c6ePav69evr9ddfv+TxCRMmaPLkyZo+fbo2bNigoKAgJSQk6MKFC/ZzLi4i9vnnn+vjjz/WV199pQcffLCkPgJwzXv22WeVm5url1566ZLHDxw4oFGjRmncuHFauHChHnroIY0cOVKrVq3Sr7/+qjfeeEOSVL9+fYWGhto7ooZhaO3aterevbu9Qyr92WFt3rx5yXxAoBBcKdfonAL4Rxc7myEhIVc87+LxjIyMfMe++OILZWdna+DAgbJY/ryiN3jwYPMKxTXFjOX2L26Fcdttt2ncuHHq2rVrvmOGYWjSpEl67rnn1LlzZ91www2aP3++Dh8+bL8SvWPHDn366aeaPXu2mjRpohYtWmjKlClauHDhFUdxAJinWrVquvfeezVz5kwdOXIk3/ElS5bIZrOpR48e+v333+1bVFSUatasqVWrVkmSvLy81KxZM3311VeS8v77/uOPPzR8+HAZhqH169dLyuuc1q1bV+Hh4SX2GeF+yDU6pwAK4GKn81Ijon91pU7sgQMHJEk1a9Z02F+uXDmVLl3ajDJxjbEZXqZtUt5Flb9uWVlZha5p//79Onr0qNq1a2ffFxYWpiZNmtj/SF2/fr3Cw8PVuHFj+znt2rWTl5eXNmzYUMTvCoCCeu6555Sbm3vJe093794twzBUs2ZNlStXzmHbsWOHjh8/bj+3ZcuW2rx5s86fP6+0tDRVqFBBDRs2VP369e0jqmvXrr3sbTHAReQaq/UCKICwsDBVqFBB33333RXP++6771SxYkWFhoaWUGWAeSpXruzw9ahRo/T8888Xqo2jR49KkiIjIx32R0ZG2o8dPXpU5cuXdzju4+OjiIgI+zkAil+1atWUlJSkmTNnavjw4Q7HbDabLBaLVqxYIW9v73yvDQ4Otv//Fi1aKCcnR+vXr1daWpq9E9qyZUulpaXp559/1okTJ+icosS5Y67ROQVQIHfccYdmzZqltWvXqkWLFvmOp6Wl6ZdfftG///3vS77+uuuuk5R3Nfqvj5o5ceKETp06VTxFw6NdzdSlS7eTt+T+oUOHHC6s+Pv7F7ltAK7tueee09tvv53v3tPq1avLMAzFxMSoVq1aV2zj5ptvlp+fn9LS0pSWlqYnn3xSUt5CgbNmzdLKlSvtXwNXQq4xrRdAAT355JMKDAzUv//9b/3xxx8Ox06ePKn//Oc/KlWqlD2U/65du3by9fXVlClTZBh/Pn9r0qRJxVk2PJhN5qxsaPv/9kJDQx22qwnxqKgoSdKxY8cc9h87dsx+LCoqymFKoCTl5ubq5MmT9nMAlIzq1asrKSlJM2bMcBjh6datm7y9vTV69GiHzJLy7sH7aw4GBATopptu0jvvvKODBw86jJyeP39ekydPVvXq1VWhQoWS+VBwW+QanVMABVSzZk3NmzdPu3fvVr169TRixAi9+eabGjlypOrVq6c9e/borbfeUvXq1S/5+nLlymno0KFavny57rjjDr3++ut64IEHNHfuXJUtW7aEPw1QPGJiYhQVFWUfKZHy7vnZsGGDmjZtKklq2rSpTp8+rc2bN9vP+fLLL2Wz2dSkSZMSrxm41j377LPKycnRzp077fuqV6+ucePGacGCBWrRooVefvllTZ8+XcOGDVPt2rU1Z84chzZatmypnTt3KiwsTPXq1ZMklS9fXrVr19auXbuY0gu3VdK5xrReAAV29913KzY2VsnJyUpJSdHvv/
+uMmXKqE2bNnrmmWdUt27dK75+3LhxCggI0PTp07Vq1So1adJE//vf/9SxY8cS+gTwJDZ5yWbCNdbCtpGZmak9e/b
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x400 with 4 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"for index in range(0, len(optimized_metrics)):\n",
|
|||
|
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"Below Average\", \"Above Average\"]\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(optimized_metrics.index[index]) \n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"В желтом квадрате мы видим значение 28511, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"Below Average\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
|
|||
|
"\n",
|
|||
|
"В зеленом квадрате значение 3952 указывает на количество правильно классифицированных объектов, отнесенных к классу \"Above Average\". Это также является показателем высокой точности модели в определении объектов данного класса."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Загрузка данных и создание целевой переменной"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 48,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Среднее значение поля 'city_latitude': 38.70460772283549\n",
|
|||
|
" summary city state \\\n",
|
|||
|
"0 Viewed some red lights in the sky appearing to... Visalia CA \n",
|
|||
|
"1 Look like 1 or 3 crafts from North traveling s... Cincinnati OH \n",
|
|||
|
"2 seen dark rectangle moving slowly thru the sky... Tecopa CA \n",
|
|||
|
"3 One red light moving switly west to east, beco... Knoxville TN \n",
|
|||
|
"4 Bright, circular Fresnel-lens shaped light sev... Alexandria VA \n",
|
|||
|
"\n",
|
|||
|
" date_time shape duration \\\n",
|
|||
|
"0 2021-12-15T21:45:00 light 2 minutes \n",
|
|||
|
"1 2021-12-16T09:45:00 triangle 14 seconds \n",
|
|||
|
"2 2021-12-10T00:00:00 rectangle Several minutes \n",
|
|||
|
"3 2021-12-10T19:30:00 triangle 20-30 seconds \n",
|
|||
|
"4 2021-12-07T08:00:00 circle NaN \n",
|
|||
|
"\n",
|
|||
|
" stats \\\n",
|
|||
|
"0 Occurred : 12/15/2021 21:45 (Entered as : 12/... \n",
|
|||
|
"1 Occurred : 12/16/2021 09:45 (Entered as : 12/... \n",
|
|||
|
"2 Occurred : 12/10/2021 00:00 (Entered as : 12/... \n",
|
|||
|
"3 Occurred : 12/10/2021 19:30 (Entered as : 12/... \n",
|
|||
|
"4 Occurred : 12/7/2021 08:00 (Entered as : 12/0... \n",
|
|||
|
"\n",
|
|||
|
" report_link \\\n",
|
|||
|
"0 http://www.nuforc.org/webreports/165/S165881.html \n",
|
|||
|
"1 http://www.nuforc.org/webreports/165/S165888.html \n",
|
|||
|
"2 http://www.nuforc.org/webreports/165/S165810.html \n",
|
|||
|
"3 http://www.nuforc.org/webreports/165/S165825.html \n",
|
|||
|
"4 http://www.nuforc.org/webreports/165/S165754.html \n",
|
|||
|
"\n",
|
|||
|
" text posted \\\n",
|
|||
|
"0 Viewed some red lights in the sky appearing to... 2021-12-19T00:00:00 \n",
|
|||
|
"1 Look like 1 or 3 crafts from North traveling s... 2021-12-19T00:00:00 \n",
|
|||
|
"2 seen dark rectangle moving slowly thru the sky... 2021-12-19T00:00:00 \n",
|
|||
|
"3 One red light moving switly west to east, beco... 2021-12-19T00:00:00 \n",
|
|||
|
"4 Bright, circular Fresnel-lens shaped light sev... 2021-12-19T00:00:00 \n",
|
|||
|
"\n",
|
|||
|
" city_latitude city_longitude above_average_city_latitude \n",
|
|||
|
"0 36.356650 -119.347937 0 \n",
|
|||
|
"1 39.174503 -84.481363 1 \n",
|
|||
|
"2 NaN NaN 0 \n",
|
|||
|
"3 35.961561 -83.980115 0 \n",
|
|||
|
"4 38.798958 -77.095133 1 \n",
|
|||
|
"Статистическое описание DataFrame:\n",
|
|||
|
" city_latitude city_longitude above_average_city_latitude\n",
|
|||
|
"count 110136.000000 110136.000000 136940.000000\n",
|
|||
|
"mean 38.704608 -95.185792 0.435928\n",
|
|||
|
"std 5.752186 18.310088 0.495880\n",
|
|||
|
"min -32.055500 -170.494000 0.000000\n",
|
|||
|
"25% 34.238375 -113.901810 0.000000\n",
|
|||
|
"50% 39.257500 -89.161450 0.000000\n",
|
|||
|
"75% 42.317739 -80.363444 1.000000\n",
|
|||
|
"max 64.845276 130.850580 1.000000\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"\n",
|
|||
|
"# Загрузка данных\n",
|
|||
|
"df = pd.read_csv(\"nuforc_reports.csv\")\n",
|
|||
|
"\n",
|
|||
|
"# Опция для настройки генерации случайных чисел \n",
|
|||
|
"random_state = 42\n",
|
|||
|
"\n",
|
|||
|
"# Вычисление среднего значения поля \"city_latitude\"\n",
|
|||
|
"average_city_latitude = df['city_latitude'].mean()\n",
|
|||
|
"print(f\"Среднее значение поля 'city_latitude': {average_city_latitude}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создание новой колонки, указывающей, выше или ниже среднего значение цены\n",
|
|||
|
"df['above_average_city_latitude'] = (df['city_latitude'] > average_city_latitude).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод DataFrame с новой колонкой\n",
|
|||
|
"print(df.head())\n",
|
|||
|
"\n",
|
|||
|
"# Примерный анализ данных\n",
|
|||
|
"print(\"Статистическое описание DataFrame:\")\n",
|
|||
|
"print(df.describe())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии\n",
|
|||
|
"\n",
|
|||
|
"Целевой признак -- above_average_city_latitude"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 49,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>summary</th>\n",
|
|||
|
" <th>city</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>date_time</th>\n",
|
|||
|
" <th>shape</th>\n",
|
|||
|
" <th>duration</th>\n",
|
|||
|
" <th>stats</th>\n",
|
|||
|
" <th>report_link</th>\n",
|
|||
|
" <th>text</th>\n",
|
|||
|
" <th>posted</th>\n",
|
|||
|
" <th>city_latitude</th>\n",
|
|||
|
" <th>city_longitude</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>124678</th>\n",
|
|||
|
" <td>4 lights moving in unison at a high rate of sp...</td>\n",
|
|||
|
" <td>Mount Juliet</td>\n",
|
|||
|
" <td>TN</td>\n",
|
|||
|
" <td>2019-02-24T20:25:00</td>\n",
|
|||
|
" <td>light</td>\n",
|
|||
|
" <td>4 seconds</td>\n",
|
|||
|
" <td>Occurred : 2/24/2019 20:25 (Entered as : 02/2...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/144/S144994.html</td>\n",
|
|||
|
" <td>4 lights moving in unison at a high rate of sp...</td>\n",
|
|||
|
" <td>2019-02-27T00:00:00</td>\n",
|
|||
|
" <td>36.172148</td>\n",
|
|||
|
" <td>-86.490748</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2212</th>\n",
|
|||
|
" <td>Observed a long line of objects moving at very...</td>\n",
|
|||
|
" <td>Bethany Beach</td>\n",
|
|||
|
" <td>DE</td>\n",
|
|||
|
" <td>2020-03-01T05:30:00</td>\n",
|
|||
|
" <td>other</td>\n",
|
|||
|
" <td>45 seconds</td>\n",
|
|||
|
" <td>Occurred : 3/1/2020 05:30 (Entered as : 03/01...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/153/S153730.html</td>\n",
|
|||
|
" <td>Observed a long line of objects moving at very...</td>\n",
|
|||
|
" <td>2020-04-09T00:00:00</td>\n",
|
|||
|
" <td>38.556200</td>\n",
|
|||
|
" <td>-75.069200</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>31960</th>\n",
|
|||
|
" <td>My wife and myself were traveling to a friends...</td>\n",
|
|||
|
" <td>Webb</td>\n",
|
|||
|
" <td>MS</td>\n",
|
|||
|
" <td>1975-12-31T19:30:00</td>\n",
|
|||
|
" <td>disk</td>\n",
|
|||
|
" <td>approx 5 mins</td>\n",
|
|||
|
" <td>Occurred : 12/31/1975 19:30 (Entered as : 12/...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/037/S37355.html</td>\n",
|
|||
|
" <td>My wife and myself were traveling to a friends...</td>\n",
|
|||
|
" <td>2004-06-18T00:00:00</td>\n",
|
|||
|
" <td>33.919100</td>\n",
|
|||
|
" <td>-90.307300</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>33954</th>\n",
|
|||
|
" <td>I was sitting at my friends house, in my car a...</td>\n",
|
|||
|
" <td>Ellensburg</td>\n",
|
|||
|
" <td>WA</td>\n",
|
|||
|
" <td>2004-03-17T08:45:00</td>\n",
|
|||
|
" <td>diamond</td>\n",
|
|||
|
" <td>10-20 minutes</td>\n",
|
|||
|
" <td>Occurred : 3/17/2004 08:45 (Entered as : 03/1...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/035/S35995.html</td>\n",
|
|||
|
" <td>I was sitting at my friends house, in my car a...</td>\n",
|
|||
|
" <td>2004-04-09T00:00:00</td>\n",
|
|||
|
" <td>46.979000</td>\n",
|
|||
|
" <td>-120.470300</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>516</th>\n",
|
|||
|
" <td>Observed two glimmering craft over Powell Rive...</td>\n",
|
|||
|
" <td>Powell River</td>\n",
|
|||
|
" <td>BC</td>\n",
|
|||
|
" <td>2020-04-09T11:00:00</td>\n",
|
|||
|
" <td>disk</td>\n",
|
|||
|
" <td>3 minutes</td>\n",
|
|||
|
" <td>Occurred : 4/9/2020 11:00 (Entered as : 04/09...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/155/S155845.html</td>\n",
|
|||
|
" <td>Observed two glimmering craft over Powell Rive...</td>\n",
|
|||
|
" <td>2020-06-25T00:00:00</td>\n",
|
|||
|
" <td>50.016300</td>\n",
|
|||
|
" <td>-124.322600</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>110268</th>\n",
|
|||
|
" <td>Big black tranparent blib in sky.</td>\n",
|
|||
|
" <td>Sedro Woolley</td>\n",
|
|||
|
" <td>WA</td>\n",
|
|||
|
" <td>2016-08-14T19:45:00</td>\n",
|
|||
|
" <td>other</td>\n",
|
|||
|
" <td>6 seconds</td>\n",
|
|||
|
" <td>Occurred : 8/14/2016 19:45 (Entered as : 08/1...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/129/S129281.html</td>\n",
|
|||
|
" <td>Big black tranparent blib in sky Walking at No...</td>\n",
|
|||
|
" <td>2016-08-16T00:00:00</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>119879</th>\n",
|
|||
|
" <td>Very fast bright light.</td>\n",
|
|||
|
" <td>Brainerd</td>\n",
|
|||
|
" <td>MN</td>\n",
|
|||
|
" <td>2017-08-19T18:00:00</td>\n",
|
|||
|
" <td>light</td>\n",
|
|||
|
" <td>5 seconds</td>\n",
|
|||
|
" <td>Occurred : 8/19/2017 18:00 (Entered as : 08/1...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/135/S135876.html</td>\n",
|
|||
|
" <td>Very fast bright light. Observed a bright ligh...</td>\n",
|
|||
|
" <td>2017-08-24T00:00:00</td>\n",
|
|||
|
" <td>46.306700</td>\n",
|
|||
|
" <td>-94.100800</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>103694</th>\n",
|
|||
|
" <td>never seen this before</td>\n",
|
|||
|
" <td>Old Tappan</td>\n",
|
|||
|
" <td>NJ</td>\n",
|
|||
|
" <td>2015-03-12T05:49:00</td>\n",
|
|||
|
" <td>triangle</td>\n",
|
|||
|
" <td>20 seconds</td>\n",
|
|||
|
" <td>Occurred : 3/12/2015 05:49 (Entered as : 0312...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/117/S117744.html</td>\n",
|
|||
|
" <td>never seen this before. moved slow then all th...</td>\n",
|
|||
|
" <td>2015-03-13T00:00:00</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>131932</th>\n",
|
|||
|
" <td>Back in the late 70s or early 80s my family wa...</td>\n",
|
|||
|
" <td>Tacoma</td>\n",
|
|||
|
" <td>WA</td>\n",
|
|||
|
" <td>1980-06-01T09:00:00</td>\n",
|
|||
|
" <td>sphere</td>\n",
|
|||
|
" <td>10 seconds</td>\n",
|
|||
|
" <td>Occurred : 6/1/1980 09:00 (Entered as : 06/01...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/164/S164812.html</td>\n",
|
|||
|
" <td>Back in the late 70s or early 80s my family wa...</td>\n",
|
|||
|
" <td>2021-10-19T00:00:00</td>\n",
|
|||
|
" <td>47.212572</td>\n",
|
|||
|
" <td>-122.459720</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>121958</th>\n",
|
|||
|
" <td>Exiting Highway 61 onto I-280 and saw a light ...</td>\n",
|
|||
|
" <td>Davenport</td>\n",
|
|||
|
" <td>IA</td>\n",
|
|||
|
" <td>2018-11-01T01:00:00</td>\n",
|
|||
|
" <td>light</td>\n",
|
|||
|
" <td>2 seconds</td>\n",
|
|||
|
" <td>Occurred : 11/1/2018 01:00 (Entered as : 11/0...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/143/S143647.html</td>\n",
|
|||
|
" <td>Exiting highway 61 on to Interstate 280 and sa...</td>\n",
|
|||
|
" <td>2018-11-09T00:00:00</td>\n",
|
|||
|
" <td>41.555164</td>\n",
|
|||
|
" <td>-90.598760</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>109552 rows × 12 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" summary city \\\n",
|
|||
|
"124678 4 lights moving in unison at a high rate of sp... Mount Juliet \n",
|
|||
|
"2212 Observed a long line of objects moving at very... Bethany Beach \n",
|
|||
|
"31960 My wife and myself were traveling to a friends... Webb \n",
|
|||
|
"33954 I was sitting at my friends house, in my car a... Ellensburg \n",
|
|||
|
"516 Observed two glimmering craft over Powell Rive... Powell River \n",
|
|||
|
"... ... ... \n",
|
|||
|
"110268 Big black tranparent blib in sky. Sedro Woolley \n",
|
|||
|
"119879 Very fast bright light. Brainerd \n",
|
|||
|
"103694 never seen this before Old Tappan \n",
|
|||
|
"131932 Back in the late 70s or early 80s my family wa... Tacoma \n",
|
|||
|
"121958 Exiting Highway 61 onto I-280 and saw a light ... Davenport \n",
|
|||
|
"\n",
|
|||
|
" state date_time shape duration \\\n",
|
|||
|
"124678 TN 2019-02-24T20:25:00 light 4 seconds \n",
|
|||
|
"2212 DE 2020-03-01T05:30:00 other 45 seconds \n",
|
|||
|
"31960 MS 1975-12-31T19:30:00 disk approx 5 mins \n",
|
|||
|
"33954 WA 2004-03-17T08:45:00 diamond 10-20 minutes \n",
|
|||
|
"516 BC 2020-04-09T11:00:00 disk 3 minutes \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"110268 WA 2016-08-14T19:45:00 other 6 seconds \n",
|
|||
|
"119879 MN 2017-08-19T18:00:00 light 5 seconds \n",
|
|||
|
"103694 NJ 2015-03-12T05:49:00 triangle 20 seconds \n",
|
|||
|
"131932 WA 1980-06-01T09:00:00 sphere 10 seconds \n",
|
|||
|
"121958 IA 2018-11-01T01:00:00 light 2 seconds \n",
|
|||
|
"\n",
|
|||
|
" stats \\\n",
|
|||
|
"124678 Occurred : 2/24/2019 20:25 (Entered as : 02/2... \n",
|
|||
|
"2212 Occurred : 3/1/2020 05:30 (Entered as : 03/01... \n",
|
|||
|
"31960 Occurred : 12/31/1975 19:30 (Entered as : 12/... \n",
|
|||
|
"33954 Occurred : 3/17/2004 08:45 (Entered as : 03/1... \n",
|
|||
|
"516 Occurred : 4/9/2020 11:00 (Entered as : 04/09... \n",
|
|||
|
"... ... \n",
|
|||
|
"110268 Occurred : 8/14/2016 19:45 (Entered as : 08/1... \n",
|
|||
|
"119879 Occurred : 8/19/2017 18:00 (Entered as : 08/1... \n",
|
|||
|
"103694 Occurred : 3/12/2015 05:49 (Entered as : 0312... \n",
|
|||
|
"131932 Occurred : 6/1/1980 09:00 (Entered as : 06/01... \n",
|
|||
|
"121958 Occurred : 11/1/2018 01:00 (Entered as : 11/0... \n",
|
|||
|
"\n",
|
|||
|
" report_link \\\n",
|
|||
|
"124678 http://www.nuforc.org/webreports/144/S144994.html \n",
|
|||
|
"2212 http://www.nuforc.org/webreports/153/S153730.html \n",
|
|||
|
"31960 http://www.nuforc.org/webreports/037/S37355.html \n",
|
|||
|
"33954 http://www.nuforc.org/webreports/035/S35995.html \n",
|
|||
|
"516 http://www.nuforc.org/webreports/155/S155845.html \n",
|
|||
|
"... ... \n",
|
|||
|
"110268 http://www.nuforc.org/webreports/129/S129281.html \n",
|
|||
|
"119879 http://www.nuforc.org/webreports/135/S135876.html \n",
|
|||
|
"103694 http://www.nuforc.org/webreports/117/S117744.html \n",
|
|||
|
"131932 http://www.nuforc.org/webreports/164/S164812.html \n",
|
|||
|
"121958 http://www.nuforc.org/webreports/143/S143647.html \n",
|
|||
|
"\n",
|
|||
|
" text \\\n",
|
|||
|
"124678 4 lights moving in unison at a high rate of sp... \n",
|
|||
|
"2212 Observed a long line of objects moving at very... \n",
|
|||
|
"31960 My wife and myself were traveling to a friends... \n",
|
|||
|
"33954 I was sitting at my friends house, in my car a... \n",
|
|||
|
"516 Observed two glimmering craft over Powell Rive... \n",
|
|||
|
"... ... \n",
|
|||
|
"110268 Big black tranparent blib in sky Walking at No... \n",
|
|||
|
"119879 Very fast bright light. Observed a bright ligh... \n",
|
|||
|
"103694 never seen this before. moved slow then all th... \n",
|
|||
|
"131932 Back in the late 70s or early 80s my family wa... \n",
|
|||
|
"121958 Exiting highway 61 on to Interstate 280 and sa... \n",
|
|||
|
"\n",
|
|||
|
" posted city_latitude city_longitude \n",
|
|||
|
"124678 2019-02-27T00:00:00 36.172148 -86.490748 \n",
|
|||
|
"2212 2020-04-09T00:00:00 38.556200 -75.069200 \n",
|
|||
|
"31960 2004-06-18T00:00:00 33.919100 -90.307300 \n",
|
|||
|
"33954 2004-04-09T00:00:00 46.979000 -120.470300 \n",
|
|||
|
"516 2020-06-25T00:00:00 50.016300 -124.322600 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"110268 2016-08-16T00:00:00 NaN NaN \n",
|
|||
|
"119879 2017-08-24T00:00:00 46.306700 -94.100800 \n",
|
|||
|
"103694 2015-03-13T00:00:00 NaN NaN \n",
|
|||
|
"131932 2021-10-19T00:00:00 47.212572 -122.459720 \n",
|
|||
|
"121958 2018-11-09T00:00:00 41.555164 -90.598760 \n",
|
|||
|
"\n",
|
|||
|
"[109552 rows x 12 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>above_average_city_latitude</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>124678</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2212</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>31960</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>33954</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>516</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>110268</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>119879</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>103694</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>131932</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>121958</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>109552 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" above_average_city_latitude\n",
|
|||
|
"124678 0\n",
|
|||
|
"2212 0\n",
|
|||
|
"31960 0\n",
|
|||
|
"33954 1\n",
|
|||
|
"516 1\n",
|
|||
|
"... ...\n",
|
|||
|
"110268 0\n",
|
|||
|
"119879 1\n",
|
|||
|
"103694 0\n",
|
|||
|
"131932 1\n",
|
|||
|
"121958 1\n",
|
|||
|
"\n",
|
|||
|
"[109552 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>summary</th>\n",
|
|||
|
" <th>city</th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>date_time</th>\n",
|
|||
|
" <th>shape</th>\n",
|
|||
|
" <th>duration</th>\n",
|
|||
|
" <th>stats</th>\n",
|
|||
|
" <th>report_link</th>\n",
|
|||
|
" <th>text</th>\n",
|
|||
|
" <th>posted</th>\n",
|
|||
|
" <th>city_latitude</th>\n",
|
|||
|
" <th>city_longitude</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>108340</th>\n",
|
|||
|
" <td>Bright light with irregular flight pattern and...</td>\n",
|
|||
|
" <td>Pittsboro</td>\n",
|
|||
|
" <td>NC</td>\n",
|
|||
|
" <td>2015-12-08T21:00:00</td>\n",
|
|||
|
" <td>light</td>\n",
|
|||
|
" <td>5 minutes</td>\n",
|
|||
|
" <td>Occurred : 12/8/2015 21:00 (Entered as : 12/0...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/124/S124610.html</td>\n",
|
|||
|
" <td>Bright light with irregular flight pattern and...</td>\n",
|
|||
|
" <td>2015-12-17T00:00:00</td>\n",
|
|||
|
" <td>35.751900</td>\n",
|
|||
|
" <td>-79.224800</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>29071</th>\n",
|
|||
|
" <td>3 lights flash less than seconds apart in a ro...</td>\n",
|
|||
|
" <td>San Luis Obispo</td>\n",
|
|||
|
" <td>CA</td>\n",
|
|||
|
" <td>2004-03-16T22:10:00</td>\n",
|
|||
|
" <td>flash</td>\n",
|
|||
|
" <td>3 seconds</td>\n",
|
|||
|
" <td>Occurred : 3/16/2004 22:10 (Entered as : 03/1...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/035/S35640.html</td>\n",
|
|||
|
" <td>3 lights flash less than seconds apart in a ro...</td>\n",
|
|||
|
" <td>2004-03-17T00:00:00</td>\n",
|
|||
|
" <td>35.262867</td>\n",
|
|||
|
" <td>-120.624789</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>89462</th>\n",
|
|||
|
" <td>A wonderful aircraft, but spooky.</td>\n",
|
|||
|
" <td>Cornwall</td>\n",
|
|||
|
" <td>ON</td>\n",
|
|||
|
" <td>2013-07-06T00:30:00</td>\n",
|
|||
|
" <td>triangle</td>\n",
|
|||
|
" <td>20 seconds</td>\n",
|
|||
|
" <td>Occurred : 7/6/2013 00:30 (Entered as : 76201...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/099/S99896.html</td>\n",
|
|||
|
" <td>A wonderful aircraft , but spooky Saw aircraft...</td>\n",
|
|||
|
" <td>2013-07-14T00:00:00</td>\n",
|
|||
|
" <td>45.056209</td>\n",
|
|||
|
" <td>-74.710143</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>124422</th>\n",
|
|||
|
" <td>I was outside facing east, speaking with a co-...</td>\n",
|
|||
|
" <td>Virginia Beach</td>\n",
|
|||
|
" <td>VA</td>\n",
|
|||
|
" <td>2019-01-09T06:35:00</td>\n",
|
|||
|
" <td>fireball</td>\n",
|
|||
|
" <td>5-7 seconds</td>\n",
|
|||
|
" <td>Occurred : 1/9/2019 06:35 (Entered as : 1-9-1...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/144/S144555.html</td>\n",
|
|||
|
" <td>I was outside facing east, speaking with a co-...</td>\n",
|
|||
|
" <td>2019-01-24T00:00:00</td>\n",
|
|||
|
" <td>36.837301</td>\n",
|
|||
|
" <td>-76.061948</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>126342</th>\n",
|
|||
|
" <td>Bright circular light in the sky - once we sta...</td>\n",
|
|||
|
" <td>Fruitland</td>\n",
|
|||
|
" <td>UT</td>\n",
|
|||
|
" <td>2019-05-11T23:40:00</td>\n",
|
|||
|
" <td>circle</td>\n",
|
|||
|
" <td>2 minutes</td>\n",
|
|||
|
" <td>Occurred : 5/11/2019 23:40 (Entered as : 5/11...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/146/S146570.html</td>\n",
|
|||
|
" <td>Bright circular light in the sky - once we sta...</td>\n",
|
|||
|
" <td>2019-06-07T00:00:00</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" <td>NaN</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>48714</th>\n",
|
|||
|
" <td>Was this white cylinder with fins a weather ba...</td>\n",
|
|||
|
" <td>Cape Coral</td>\n",
|
|||
|
" <td>FL</td>\n",
|
|||
|
" <td>2007-04-29T15:30:00</td>\n",
|
|||
|
" <td>cylinder</td>\n",
|
|||
|
" <td>30 mins</td>\n",
|
|||
|
" <td>Occurred : 4/29/2007 15:30 (Entered as : 04/2...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/056/S56462.html</td>\n",
|
|||
|
" <td>Was this white cylinder with fins a weather ba...</td>\n",
|
|||
|
" <td>2007-06-12T00:00:00</td>\n",
|
|||
|
" <td>26.616422</td>\n",
|
|||
|
" <td>-81.970066</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25512</th>\n",
|
|||
|
" <td>I was driving down a rural road near Haymarket...</td>\n",
|
|||
|
" <td>Haymarket</td>\n",
|
|||
|
" <td>VA</td>\n",
|
|||
|
" <td>2003-03-11T22:00:00</td>\n",
|
|||
|
" <td>triangle</td>\n",
|
|||
|
" <td>3 minutes</td>\n",
|
|||
|
" <td>Occurred : 3/11/2003 22:00 (Entered as : 3-11...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/028/S28125.html</td>\n",
|
|||
|
" <td>I was driving down a rural road near Haymarket...</td>\n",
|
|||
|
" <td>2003-03-21T00:00:00</td>\n",
|
|||
|
" <td>38.869400</td>\n",
|
|||
|
" <td>-77.637300</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>96155</th>\n",
|
|||
|
" <td>Five orange lights in the sky over Essex area,...</td>\n",
|
|||
|
" <td>Essex Junction</td>\n",
|
|||
|
" <td>VT</td>\n",
|
|||
|
" <td>2014-05-25T21:50:00</td>\n",
|
|||
|
" <td>light</td>\n",
|
|||
|
" <td>10 minutes</td>\n",
|
|||
|
" <td>Occurred : 5/25/2014 21:50 (Entered as : 05/2...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/109/S109718.html</td>\n",
|
|||
|
" <td>Five orange lights in the sky over Essex area,...</td>\n",
|
|||
|
" <td>2014-06-04T00:00:00</td>\n",
|
|||
|
" <td>44.532199</td>\n",
|
|||
|
" <td>-73.058631</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>82188</th>\n",
|
|||
|
" <td>It was slowly moving over our pastures going a...</td>\n",
|
|||
|
" <td>New Washington</td>\n",
|
|||
|
" <td>OH</td>\n",
|
|||
|
" <td>2013-01-24T21:00:00</td>\n",
|
|||
|
" <td>circle</td>\n",
|
|||
|
" <td>5 minutes</td>\n",
|
|||
|
" <td>Occurred : 1/24/2013 21:00 (Entered as : 1/24...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/096/S96095.html</td>\n",
|
|||
|
" <td>It was slowly moving over our pastures going a...</td>\n",
|
|||
|
" <td>2013-02-04T00:00:00</td>\n",
|
|||
|
" <td>40.945000</td>\n",
|
|||
|
" <td>-82.861800</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>124520</th>\n",
|
|||
|
" <td>2 circles in the sky changed colors several ti...</td>\n",
|
|||
|
" <td>South Amboy</td>\n",
|
|||
|
" <td>NJ</td>\n",
|
|||
|
" <td>2019-02-02T00:28:00</td>\n",
|
|||
|
" <td>circle</td>\n",
|
|||
|
" <td>1 minute</td>\n",
|
|||
|
" <td>Occurred : 2/2/2019 00:28 (Entered as : 02/2/...</td>\n",
|
|||
|
" <td>http://www.nuforc.org/webreports/144/S144752.html</td>\n",
|
|||
|
" <td>2 circles in the sky changed colors several ti...</td>\n",
|
|||
|
" <td>2019-02-07T00:00:00</td>\n",
|
|||
|
" <td>40.477900</td>\n",
|
|||
|
" <td>-74.290700</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>27388 rows × 12 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" summary city \\\n",
|
|||
|
"108340 Bright light with irregular flight pattern and... Pittsboro \n",
|
|||
|
"29071 3 lights flash less than seconds apart in a ro... San Luis Obispo \n",
|
|||
|
"89462 A wonderful aircraft, but spooky. Cornwall \n",
|
|||
|
"124422 I was outside facing east, speaking with a co-... Virginia Beach \n",
|
|||
|
"126342 Bright circular light in the sky - once we sta... Fruitland \n",
|
|||
|
"... ... ... \n",
|
|||
|
"48714 Was this white cylinder with fins a weather ba... Cape Coral \n",
|
|||
|
"25512 I was driving down a rural road near Haymarket... Haymarket \n",
|
|||
|
"96155 Five orange lights in the sky over Essex area,... Essex Junction \n",
|
|||
|
"82188 It was slowly moving over our pastures going a... New Washington \n",
|
|||
|
"124520 2 circles in the sky changed colors several ti... South Amboy \n",
|
|||
|
"\n",
|
|||
|
" state date_time shape duration \\\n",
|
|||
|
"108340 NC 2015-12-08T21:00:00 light 5 minutes \n",
|
|||
|
"29071 CA 2004-03-16T22:10:00 flash 3 seconds \n",
|
|||
|
"89462 ON 2013-07-06T00:30:00 triangle 20 seconds \n",
|
|||
|
"124422 VA 2019-01-09T06:35:00 fireball 5-7 seconds \n",
|
|||
|
"126342 UT 2019-05-11T23:40:00 circle 2 minutes \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"48714 FL 2007-04-29T15:30:00 cylinder 30 mins \n",
|
|||
|
"25512 VA 2003-03-11T22:00:00 triangle 3 minutes \n",
|
|||
|
"96155 VT 2014-05-25T21:50:00 light 10 minutes \n",
|
|||
|
"82188 OH 2013-01-24T21:00:00 circle 5 minutes \n",
|
|||
|
"124520 NJ 2019-02-02T00:28:00 circle 1 minute \n",
|
|||
|
"\n",
|
|||
|
" stats \\\n",
|
|||
|
"108340 Occurred : 12/8/2015 21:00 (Entered as : 12/0... \n",
|
|||
|
"29071 Occurred : 3/16/2004 22:10 (Entered as : 03/1... \n",
|
|||
|
"89462 Occurred : 7/6/2013 00:30 (Entered as : 76201... \n",
|
|||
|
"124422 Occurred : 1/9/2019 06:35 (Entered as : 1-9-1... \n",
|
|||
|
"126342 Occurred : 5/11/2019 23:40 (Entered as : 5/11... \n",
|
|||
|
"... ... \n",
|
|||
|
"48714 Occurred : 4/29/2007 15:30 (Entered as : 04/2... \n",
|
|||
|
"25512 Occurred : 3/11/2003 22:00 (Entered as : 3-11... \n",
|
|||
|
"96155 Occurred : 5/25/2014 21:50 (Entered as : 05/2... \n",
|
|||
|
"82188 Occurred : 1/24/2013 21:00 (Entered as : 1/24... \n",
|
|||
|
"124520 Occurred : 2/2/2019 00:28 (Entered as : 02/2/... \n",
|
|||
|
"\n",
|
|||
|
" report_link \\\n",
|
|||
|
"108340 http://www.nuforc.org/webreports/124/S124610.html \n",
|
|||
|
"29071 http://www.nuforc.org/webreports/035/S35640.html \n",
|
|||
|
"89462 http://www.nuforc.org/webreports/099/S99896.html \n",
|
|||
|
"124422 http://www.nuforc.org/webreports/144/S144555.html \n",
|
|||
|
"126342 http://www.nuforc.org/webreports/146/S146570.html \n",
|
|||
|
"... ... \n",
|
|||
|
"48714 http://www.nuforc.org/webreports/056/S56462.html \n",
|
|||
|
"25512 http://www.nuforc.org/webreports/028/S28125.html \n",
|
|||
|
"96155 http://www.nuforc.org/webreports/109/S109718.html \n",
|
|||
|
"82188 http://www.nuforc.org/webreports/096/S96095.html \n",
|
|||
|
"124520 http://www.nuforc.org/webreports/144/S144752.html \n",
|
|||
|
"\n",
|
|||
|
" text \\\n",
|
|||
|
"108340 Bright light with irregular flight pattern and... \n",
|
|||
|
"29071 3 lights flash less than seconds apart in a ro... \n",
|
|||
|
"89462 A wonderful aircraft , but spooky Saw aircraft... \n",
|
|||
|
"124422 I was outside facing east, speaking with a co-... \n",
|
|||
|
"126342 Bright circular light in the sky - once we sta... \n",
|
|||
|
"... ... \n",
|
|||
|
"48714 Was this white cylinder with fins a weather ba... \n",
|
|||
|
"25512 I was driving down a rural road near Haymarket... \n",
|
|||
|
"96155 Five orange lights in the sky over Essex area,... \n",
|
|||
|
"82188 It was slowly moving over our pastures going a... \n",
|
|||
|
"124520 2 circles in the sky changed colors several ti... \n",
|
|||
|
"\n",
|
|||
|
" posted city_latitude city_longitude \n",
|
|||
|
"108340 2015-12-17T00:00:00 35.751900 -79.224800 \n",
|
|||
|
"29071 2004-03-17T00:00:00 35.262867 -120.624789 \n",
|
|||
|
"89462 2013-07-14T00:00:00 45.056209 -74.710143 \n",
|
|||
|
"124422 2019-01-24T00:00:00 36.837301 -76.061948 \n",
|
|||
|
"126342 2019-06-07T00:00:00 NaN NaN \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"48714 2007-06-12T00:00:00 26.616422 -81.970066 \n",
|
|||
|
"25512 2003-03-21T00:00:00 38.869400 -77.637300 \n",
|
|||
|
"96155 2014-06-04T00:00:00 44.532199 -73.058631 \n",
|
|||
|
"82188 2013-02-04T00:00:00 40.945000 -82.861800 \n",
|
|||
|
"124520 2019-02-07T00:00:00 40.477900 -74.290700 \n",
|
|||
|
"\n",
|
|||
|
"[27388 rows x 12 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>above_average_city_latitude</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>108340</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>29071</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>89462</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>124422</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>126342</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>48714</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>25512</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>96155</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>82188</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>124520</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>27388 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" above_average_city_latitude\n",
|
|||
|
"108340 0\n",
|
|||
|
"29071 0\n",
|
|||
|
"89462 1\n",
|
|||
|
"124422 0\n",
|
|||
|
"126342 0\n",
|
|||
|
"... ...\n",
|
|||
|
"48714 0\n",
|
|||
|
"25512 1\n",
|
|||
|
"96155 1\n",
|
|||
|
"82188 1\n",
|
|||
|
"124520 1\n",
|
|||
|
"\n",
|
|||
|
"[27388 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from typing import Tuple\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from pandas import DataFrame\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"def split_into_train_test(\n",
|
|||
|
" df_input: DataFrame,\n",
|
|||
|
" target_colname: str = \"above_average_city_latitude\", \n",
|
|||
|
" frac_train: float = 0.8,\n",
|
|||
|
" random_state: int = None,\n",
|
|||
|
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
|||
|
" \n",
|
|||
|
" if not (0 < frac_train < 1):\n",
|
|||
|
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
|
|||
|
" \n",
|
|||
|
" # Проверка наличия целевого признака\n",
|
|||
|
" if target_colname not in df_input.columns:\n",
|
|||
|
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
|
|||
|
" \n",
|
|||
|
" # Разделяем данные на признаки и целевую переменную\n",
|
|||
|
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
|
|||
|
" y = df_input[[target_colname]] # Целевая переменная\n",
|
|||
|
"\n",
|
|||
|
" # Разделяем данные на обучающую и тестовую выборки\n",
|
|||
|
" X_train, X_test, y_train, y_test = train_test_split(\n",
|
|||
|
" X, y,\n",
|
|||
|
" test_size=(1.0 - frac_train),\n",
|
|||
|
" random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" return X_train, X_test, y_train, y_test\n",
|
|||
|
"\n",
|
|||
|
"# Применение функции для разделения данных\n",
|
|||
|
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
|
|||
|
" df, \n",
|
|||
|
" target_colname=\"above_average_city_latitude\", \n",
|
|||
|
" frac_train=0.8, \n",
|
|||
|
" random_state=42 \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Для отображения результатов\n",
|
|||
|
"display(\"X_train\", X_train)\n",
|
|||
|
"display(\"y_train\", y_train)\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_test\", X_test)\n",
|
|||
|
"display(\"y_test\", y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование конвейера для решения задачи регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 50,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" city_latitude city_longitude shape_changing shape_chevron \\\n",
|
|||
|
"0 -0.475704 -1.527173 0.0 0.0 \n",
|
|||
|
"1 0.070048 0.574031 0.0 0.0 \n",
|
|||
|
"2 0.086123 0.291990 0.0 0.0 \n",
|
|||
|
"3 -0.552224 0.604238 0.0 0.0 \n",
|
|||
|
"4 -0.002686 1.019156 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"136935 1.426621 -1.516027 0.0 0.0 \n",
|
|||
|
"136936 -2.044798 0.757950 0.0 0.0 \n",
|
|||
|
"136937 0.180564 -0.660057 0.0 0.0 \n",
|
|||
|
"136938 -0.719495 -1.453787 0.0 0.0 \n",
|
|||
|
"136939 0.785101 1.328590 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" shape_cigar shape_circle shape_cone shape_crescent shape_cross \\\n",
|
|||
|
"0 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"4 0.0 1.0 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"136935 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"136936 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"136937 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"136938 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"136939 0.0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" shape_cylinder ... shape_light shape_other shape_oval \\\n",
|
|||
|
"0 0.0 ... 1.0 0.0 0.0 \n",
|
|||
|
"1 0.0 ... 0.0 0.0 0.0 \n",
|
|||
|
"2 0.0 ... 0.0 0.0 0.0 \n",
|
|||
|
"3 0.0 ... 0.0 0.0 0.0 \n",
|
|||
|
"4 0.0 ... 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"136935 0.0 ... 0.0 0.0 0.0 \n",
|
|||
|
"136936 1.0 ... 0.0 0.0 0.0 \n",
|
|||
|
"136937 0.0 ... 1.0 0.0 0.0 \n",
|
|||
|
"136938 0.0 ... 0.0 0.0 0.0 \n",
|
|||
|
"136939 0.0 ... 0.0 0.0 1.0 \n",
|
|||
|
"\n",
|
|||
|
" shape_pyramid shape_rectangle shape_round shape_sphere \\\n",
|
|||
|
"0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2 0.0 1.0 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"136935 0.0 0.0 0.0 0.0 \n",
|
|||
|
"136936 0.0 0.0 0.0 0.0 \n",
|
|||
|
"136937 0.0 0.0 0.0 0.0 \n",
|
|||
|
"136938 0.0 0.0 0.0 0.0 \n",
|
|||
|
"136939 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" shape_teardrop shape_triangle shape_unknown \n",
|
|||
|
"0 0.0 0.0 0.0 \n",
|
|||
|
"1 0.0 1.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 \n",
|
|||
|
"3 0.0 1.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"136935 0.0 1.0 0.0 \n",
|
|||
|
"136936 0.0 0.0 0.0 \n",
|
|||
|
"136937 0.0 0.0 0.0 \n",
|
|||
|
"136938 0.0 0.0 0.0 \n",
|
|||
|
"136939 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
"[136940 rows x 30 columns]\n",
|
|||
|
"(136940, 30)\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn.base import BaseEstimator, TransformerMixin\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor \n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.pipeline import make_pipeline\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"class JioMartFeatures(BaseEstimator, TransformerMixin): \n",
|
|||
|
" def __init__(self):\n",
|
|||
|
" pass\n",
|
|||
|
"\n",
|
|||
|
" def fit(self, X, y=None):\n",
|
|||
|
" return self\n",
|
|||
|
"\n",
|
|||
|
" def transform(self, X, y=None):\n",
|
|||
|
" if 'state' in X.columns:\n",
|
|||
|
" X[\"city_latitude_per_state\"] = X[\"city_latitude\"] / X[\"state\"].nunique()\n",
|
|||
|
" return X\n",
|
|||
|
"\n",
|
|||
|
" def get_feature_names_out(self, features_in):\n",
|
|||
|
" return np.append(features_in, [\"city_latitude_per_state\"], axis=0) \n",
|
|||
|
"\n",
|
|||
|
"# Определите признаки для вашей задачи\n",
|
|||
|
"columns_to_drop = [\"date_time\", \"posted\", \"city\", \"state\", \"summary\", \"stats\", \"report_link\", \"duration\", \"text\"] # Столбцы, которые можно удалить\n",
|
|||
|
"num_columns = [\"city_latitude\", \"city_longitude\"] # Числовые столбцы\n",
|
|||
|
"cat_columns = [\"shape\"] # Категориальные столбцы\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование числовых признаков\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование категориальных признаков\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Формирование конвейера\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\" \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Окончательный конвейер\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" (\"custom_features\", JioMartFeatures()), # Добавляем custom_features\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Загрузка данных\n",
|
|||
|
"df = pd.read_csv(\"nuforc_reports.csv\")\n",
|
|||
|
"\n",
|
|||
|
"# Создаем целевой признак\n",
|
|||
|
"average_city_latitude = df['city_latitude'].mean()\n",
|
|||
|
"df['above_average_city_latitude'] = (df['city_latitude'] > average_city_latitude).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Подготовка данных\n",
|
|||
|
"X = df.drop('above_average_city_latitude', axis=1)\n",
|
|||
|
"y = df['above_average_city_latitude'].values.ravel()\n",
|
|||
|
"\n",
|
|||
|
"# Проверка наличия столбцов перед применением конвейера\n",
|
|||
|
"required_columns = set(num_columns + cat_columns + columns_to_drop)\n",
|
|||
|
"missing_columns = required_columns - set(X.columns)\n",
|
|||
|
"if missing_columns:\n",
|
|||
|
" raise KeyError(f\"Missing columns: {missing_columns}\")\n",
|
|||
|
"\n",
|
|||
|
"# Применение конвейера\n",
|
|||
|
"X_processed = pipeline_end.fit_transform(X)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод\n",
|
|||
|
"print(X_processed)\n",
|
|||
|
"print(X_processed.shape)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование набора моделей для регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 51,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Random Forest: Mean Score = 0.9794675822991522, Standard Deviation = 0.016333217085689338\n",
|
|||
|
"Linear Regression: Mean Score = 0.5039253856797983, Standard Deviation = 0.030322793232352978\n",
|
|||
|
"Gradient Boosting: Mean Score = 0.9901727931253768, Standard Deviation = 0.008628774764973144\n",
|
|||
|
"Support Vector Regression: Mean Score = 0.8080621690604891, Standard Deviation = 0.04395269414319326\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.linear_model import LinearRegression\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|||
|
"from sklearn.model_selection import cross_val_score\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.ensemble import GradientBoostingRegressor\n",
|
|||
|
"from sklearn.svm import SVR\n",
|
|||
|
"\n",
|
|||
|
"def train_multiple_models(X, y, models, cv=3):\n",
|
|||
|
" results = {}\n",
|
|||
|
" for model_name, model in models.items():\n",
|
|||
|
" # Создаем конвейер для каждой модели\n",
|
|||
|
" model_pipeline = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" (\"model\", model) # Используем текущую модель\n",
|
|||
|
" ]\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" # Обучаем модель и вычисляем кросс-валидацию\n",
|
|||
|
" scores = cross_val_score(model_pipeline, X, y, cv=cv, n_jobs=-1) # Используем все ядра процессора\n",
|
|||
|
" results[model_name] = {\n",
|
|||
|
" \"mean_score\": scores.mean(),\n",
|
|||
|
" \"std_dev\": scores.std()\n",
|
|||
|
" }\n",
|
|||
|
" \n",
|
|||
|
" return results\n",
|
|||
|
"\n",
|
|||
|
"# Определение моделей\n",
|
|||
|
"models = {\n",
|
|||
|
" \"Random Forest\": RandomForestRegressor(n_estimators=10), # Уменьшаем количество деревьев\n",
|
|||
|
" \"Linear Regression\": LinearRegression(),\n",
|
|||
|
" \"Gradient Boosting\": GradientBoostingRegressor(),\n",
|
|||
|
" \"Support Vector Regression\": SVR()\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Используем подвыборку данных\n",
|
|||
|
"sample_size = 1000 # Уменьшаем количество данных для обучения\n",
|
|||
|
"X_train_sample = X_train.sample(n=sample_size, random_state=42)\n",
|
|||
|
"y_train_sample = y_train.loc[X_train_sample.index] # Используем loc для индексации Series\n",
|
|||
|
"\n",
|
|||
|
"# Обучение моделей и вывод результатов\n",
|
|||
|
"results = train_multiple_models(X_train_sample, y_train_sample, models, cv=3) # Уменьшаем количество фолдов\n",
|
|||
|
"\n",
|
|||
|
"# Вывод результатов\n",
|
|||
|
"for model_name, scores in results.items():\n",
|
|||
|
" print(f\"{model_name}: Mean Score = {scores['mean_score']}, Standard Deviation = {scores['std_dev']}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Модель: Random Forest\n",
|
|||
|
"- **Mean Score**: 0.9794675822991522\n",
|
|||
|
"- **Standard Deviation**: 0.016333217085689338\n",
|
|||
|
"**Описание**:\n",
|
|||
|
"- Random Forest показала очень высокое среднее значение, близкое к 1, что указывает на ее высокую точность в предсказании. Стандартное отклонение также относительно низкое, что говорит о стабильности модели.\n",
|
|||
|
"\n",
|
|||
|
"#### Модель: Linear Regression\n",
|
|||
|
"- **Mean Score**: 0.5039253856797983\n",
|
|||
|
"- **Standard Deviation**: 0.030322793232352978\n",
|
|||
|
"**Описание**:\n",
|
|||
|
"- Линейная регрессия показала самое низкое среднее значение (около 0.50): модель объясняет лишь примерно половину дисперсии целевой переменной и заметно уступает остальным моделям. Стандартное отклонение при этом невелико, то есть результат стабилен между фолдами.\n",
|
|||
|
"\n",
|
|||
|
"#### Модель: Gradient Boosting\n",
|
|||
|
"- **Mean Score**: 0.9901727931253768\n",
|
|||
|
"- **Standard Deviation**: 0.008628774764973144\n",
|
|||
|
"**Описание**:\n",
|
|||
|
"- Gradient Boosting показала практически идеальное среднее значение, близкое к 1, что указывает на ее высокую точность в предсказании. Стандартное отклонение относительно низкое, что говорит о стабильности модели.\n",
|
|||
|
"\n",
|
|||
|
"#### Модель: Support Vector Regression\n",
|
|||
|
"- **Mean Score**: 0.8080621690604891\n",
|
|||
|
"- **Standard Deviation**: 0.04395269414319326\n",
|
|||
|
"**Описание**:\n",
|
|||
|
"- Support Vector Regression показала среднее значение около 0.81, что указывает на ее умеренную точность в предсказании. Стандартное отклонение относительно низкое, что говорит о стабильности модели, но она все же уступает Random Forest и Gradient Boosting.\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"1. **Random Forest и Gradient Boosting** демонстрируют высокую точность и стабильность, что делает их наиболее подходящими моделями для данной задачи регрессии.\n",
|
|||
|
"2. **Linear Regression** показывает наименьшее качество (среднее значение около 0.50), поэтому для данной задачи предпочтительнее более гибкие модели.\n",
|
|||
|
"3. **Support Vector Regression** показывает умеренную точность и стабильность, но уступает Random Forest и Gradient Boosting в эффективности."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение моделей на обучающем наборе данных и оценка на тестовом для регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 52,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.04693661457572659\n",
|
|||
|
"MSE (test): 0.04651672265225646\n",
|
|||
|
"MAE (train): 0.04693661457572659\n",
|
|||
|
"MAE (test): 0.04651672265225646\n",
|
|||
|
"R2 (train): 0.8090856146775022\n",
|
|||
|
"R2 (test): 0.810957714373292\n",
|
|||
|
"STD (train): 0.21150311767890387\n",
|
|||
|
"STD (test): 0.21060132280199356\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: ridge\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.04438987877902731\n",
|
|||
|
"MSE (test): 0.04465459325251935\n",
|
|||
|
"MAE (train): 0.04438987877902731\n",
|
|||
|
"MAE (test): 0.04465459325251935\n",
|
|||
|
"R2 (train): 0.8194444465532269\n",
|
|||
|
"R2 (test): 0.8185253411919435\n",
|
|||
|
"STD (train): 0.20595974713766416\n",
|
|||
|
"STD (test): 0.2065443307233859\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: decision_tree\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.0\n",
|
|||
|
"MSE (test): 0.0\n",
|
|||
|
"MAE (train): 0.0\n",
|
|||
|
"MAE (test): 0.0\n",
|
|||
|
"R2 (train): 1.0\n",
|
|||
|
"R2 (test): 1.0\n",
|
|||
|
"STD (train): 0.0\n",
|
|||
|
"STD (test): 0.0\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: knn\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.0024737111143566526\n",
|
|||
|
"MSE (test): 0.0035416970936176426\n",
|
|||
|
"MAE (train): 0.0024737111143566526\n",
|
|||
|
"MAE (test): 0.0035416970936176426\n",
|
|||
|
"R2 (train): 0.9899381955615719\n",
|
|||
|
"R2 (test): 0.9856066705606039\n",
|
|||
|
"STD (train): 0.04973611399311279\n",
|
|||
|
"STD (test): 0.05950515838317305\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: naive_bayes\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.12056375054768512\n",
|
|||
|
"MSE (test): 0.12202424419453775\n",
|
|||
|
"MAE (train): 0.12056375054768512\n",
|
|||
|
"MAE (test): 0.12202424419453775\n",
|
|||
|
"R2 (train): 0.5096077010230355\n",
|
|||
|
"R2 (test): 0.5040978661189495\n",
|
|||
|
"STD (train): 0.32835683006894384\n",
|
|||
|
"STD (test): 0.33012146600139913\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: gradient_boosting\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.0\n",
|
|||
|
"MSE (test): 0.0\n",
|
|||
|
"MAE (train): 0.0\n",
|
|||
|
"MAE (test): 0.0\n",
|
|||
|
"R2 (train): 1.0\n",
|
|||
|
"R2 (test): 1.0\n",
|
|||
|
"STD (train): 0.0\n",
|
|||
|
"STD (test): 0.0\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: random_forest\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.00040163575288447494\n",
|
|||
|
"MSE (test): 0.0005476851175697386\n",
|
|||
|
"MAE (train): 0.00040163575288447494\n",
|
|||
|
"MAE (test): 0.0005476851175697386\n",
|
|||
|
"R2 (train): 0.9983663490948678\n",
|
|||
|
"R2 (test): 0.997774227406279\n",
|
|||
|
"STD (train): 0.020036827134216634\n",
|
|||
|
"STD (test): 0.023396263773981758\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: mlp\n",
|
|||
|
"MSE (train): 0.0206659851029648\n",
|
|||
|
"MSE (test): 0.01971666423251059\n",
|
|||
|
"MAE (train): 0.0206659851029648\n",
|
|||
|
"MAE (test): 0.01971666423251059\n",
|
|||
|
"R2 (train): 0.9159412352450146\n",
|
|||
|
"R2 (test): 0.9198721866260421\n",
|
|||
|
"STD (train): 0.1430221622813633\n",
|
|||
|
"STD (test): 0.13976464450173487\n",
|
|||
|
"----------------------------------------\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"\n",
|
|||
|
"# Проверка наличия необходимых переменных\n",
|
|||
|
"if 'class_models' not in locals():\n",
|
|||
|
" raise ValueError(\"class_models is not defined\")\n",
|
|||
|
"if 'X_train' not in locals() or 'X_test' not in locals() or 'y_train' not in locals() or 'y_test' not in locals():\n",
|
|||
|
" raise ValueError(\"Train/test data is not defined\")\n",
|
|||
|
"\n",
|
|||
|
"# Преобразуем y_train и y_test в одномерные массивы\n",
|
|||
|
"y_train = np.ravel(y_train) \n",
|
|||
|
"y_test = np.ravel(y_test) \n",
|
|||
|
"\n",
|
|||
|
"# Инициализация списка для хранения результатов\n",
|
|||
|
"results = []\n",
|
|||
|
"\n",
|
|||
|
"# Проход по моделям и оценка их качества\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" \n",
|
|||
|
" # Извлечение модели из словаря\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
" \n",
|
|||
|
" # Создание пайплайна\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
|||
|
" \n",
|
|||
|
" # Обучение модели\n",
|
|||
|
" model_pipeline.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
" # Предсказание для обучающей и тестовой выборки\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_predict = model_pipeline.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
" # Сохранение пайплайна и предсказаний\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" # Вычисление метрик для регрессии\n",
|
|||
|
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
|
|||
|
"\n",
|
|||
|
" # Дополнительные метрики\n",
|
|||
|
" class_models[model_name][\"STD_train\"] = np.std(y_train - y_train_predict)\n",
|
|||
|
" class_models[model_name][\"STD_test\"] = np.std(y_test - y_test_predict)\n",
|
|||
|
"\n",
|
|||
|
" # Вывод результатов для текущей модели\n",
|
|||
|
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
|
|||
|
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
|
|||
|
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
|
|||
|
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
|
|||
|
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
|
|||
|
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
|
|||
|
" print(f\"STD (train): {class_models[model_name]['STD_train']}\")\n",
|
|||
|
" print(f\"STD (test): {class_models[model_name]['STD_test']}\")\n",
|
|||
|
" print(\"-\" * 40) # Разделитель для разных моделей"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Пример использования обученной модели (конвейера регрессии) для предсказания"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 53,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: RandomForest\n",
|
|||
|
"MSE (train): 0.025664413833537104\n",
|
|||
|
"MSE (test): 0.12705862332487192\n",
|
|||
|
"MAE (train): 0.017621854354383706\n",
|
|||
|
"MAE (test): 0.06595286282214888\n",
|
|||
|
"R2 (train): 0.9991314892285292\n",
|
|||
|
"R2 (test): 0.9964919352072331\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Прогнозируемая цена: 25.070991060937757\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor \n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder\n",
|
|||
|
"\n",
|
|||
|
"# 1. Загрузка данных\n",
|
|||
|
"data = pd.read_csv(\"nuforc_reports.csv\") \n",
|
|||
|
"data = data.head(1000)\n",
|
|||
|
"\n",
|
|||
|
"# 2. Подготовка данных для прогноза\n",
|
|||
|
"average_city_latitude = data['city_latitude'].mean()\n",
|
|||
|
"data['above_average_city_latitude'] = (data['city_latitude'] > average_city_latitude).astype(int) \n",
|
|||
|
"\n",
|
|||
|
"# Удаляем строки с пропущенными значениями в столбце 'city_latitude'\n",
|
|||
|
"data = data.dropna(subset=['city_latitude'])\n",
|
|||
|
"\n",
|
|||
|
"# Предикторы и целевая переменная\n",
|
|||
|
"X = data.drop('above_average_city_latitude', axis=1) # Удаляем только 'above_average_city_latitude'\n",
|
|||
|
"y = data['city_latitude']\n",
|
|||
|
"\n",
|
|||
|
"# 3. Инициализация модели и пайплайна\n",
|
|||
|
"class_models = {\n",
|
|||
|
" \"RandomForest\": {\n",
|
|||
|
" \"model\": RandomForestRegressor(n_estimators=100, random_state=42),\n",
|
|||
|
" }\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Предобработка признаков\n",
|
|||
|
"num_columns = ['city_latitude']\n",
|
|||
|
"cat_columns = ['state', 'city']\n",
|
|||
|
"\n",
|
|||
|
"# Проверка наличия столбцов перед предобработкой\n",
|
|||
|
"required_columns = set(num_columns + cat_columns)\n",
|
|||
|
"missing_columns = required_columns - set(X.columns)\n",
|
|||
|
"if missing_columns:\n",
|
|||
|
" raise KeyError(f\"Missing columns: {missing_columns}\")\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование числовых признаков\n",
|
|||
|
"num_transformer = Pipeline(steps=[\n",
|
|||
|
" ('imputer', SimpleImputer(strategy='median')),\n",
|
|||
|
" ('scaler', StandardScaler())\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование категориальных признаков\n",
|
|||
|
"cat_transformer = Pipeline(steps=[\n",
|
|||
|
" ('imputer', SimpleImputer(strategy='constant', fill_value='unknown')),\n",
|
|||
|
" ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop=\"first\"))\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Создание конвейера предобработки\n",
|
|||
|
"preprocessor = ColumnTransformer(\n",
|
|||
|
" transformers=[\n",
|
|||
|
" ('num', num_transformer, num_columns),\n",
|
|||
|
" ('cat', cat_transformer, cat_columns)\n",
|
|||
|
" ])\n",
|
|||
|
"\n",
|
|||
|
"# Создание конвейера модели\n",
|
|||
|
"pipeline_end = Pipeline(steps=[\n",
|
|||
|
" ('preprocessor', preprocessor),\n",
|
|||
|
" # ('model', model) # Модель добавляется в цикле\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"results = []\n",
|
|||
|
"\n",
|
|||
|
"# 4. Обучение модели и оценка\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
"\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
" model_pipeline = Pipeline(steps=[\n",
|
|||
|
" ('preprocessor', preprocessor),\n",
|
|||
|
" ('model', model)\n",
|
|||
|
" ])\n",
|
|||
|
"\n",
|
|||
|
" # Разделение данных\n",
|
|||
|
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
|
|||
|
"\n",
|
|||
|
" # Обучение модели\n",
|
|||
|
" model_pipeline.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
" # Предсказание\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_predict = model_pipeline.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
" # Сохранение результатов\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" # Вычисление метрик\n",
|
|||
|
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
|
|||
|
"\n",
|
|||
|
" # Вывод результатов\n",
|
|||
|
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
|
|||
|
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
|
|||
|
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
|
|||
|
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
|
|||
|
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
|
|||
|
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
|
|||
|
" print(\"-\" * 40)\n",
|
|||
|
"\n",
|
|||
"# Predict the latitude (city_latitude) for a new observation\n",
|
|||
|
"new_item_data = pd.DataFrame({\n",
|
|||
|
" 'state': ['Electronics'],\n",
|
|||
|
" 'city': ['Smartphones'], \n",
|
|||
|
" 'city_latitude': [0] # Добавляем столбец 'city_latitude' с нулевым значением\n",
|
|||
|
"})\n",
|
|||
|
"\n",
|
|||
|
"predicted_city_latitude = model_pipeline.predict(new_item_data)\n",
|
|||
|
"print(f\"Прогнозируемая цена: {predicted_city_latitude[0]}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Подбор гиперпараметров методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 54,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
|
|||
|
"Лучшие параметры: {'max_depth': None, 'min_samples_split': 2, 'n_estimators': 100}\n",
|
|||
|
"Лучший результат (MSE): 1.2759929698621533\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"\n",
"# Grid search over RandomForestRegressor hyperparameters for predicting 'city_latitude'.\n",
"# NOTE(review): `df` is not created in this cell; presumably an earlier cell loads it.\n",
"# TODO confirm the notebook still runs after Restart Kernel -> Run All.\n",
"\n",
"# One seed for both the train/test split and the forest, so repeated runs are reproducible\n",
"RANDOM_STATE = 42\n",
"\n",
"# Keep complete rows in a new frame instead of rebinding `df`, so the input frame stays intact\n",
"df_clean = df.dropna()\n",
"\n",
"# Target variable (city_latitude)\n",
"target = df_clean['city_latitude']\n",
"\n",
"# Remove the target from the feature frame\n",
"features = df_clean.drop(columns=['city_latitude'])\n",
"\n",
"# Drop the columns listed below; they are not used as features in this search\n",
"features = features.drop(columns=[\"date_time\", \"posted\", \"city\", \"state\", \"summary\", \"stats\", \"report_link\", \"duration\", \"text\"])\n",
"\n",
"# Split the remaining columns by dtype\n",
"num_columns = features.select_dtypes(include=['number']).columns\n",
"cat_columns = features.select_dtypes(include=['object']).columns\n",
"\n",
"# Numeric columns: fill gaps with the median, then standardise\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
"    [\n",
"        (\"imputer\", num_imputer),\n",
"        (\"scaler\", num_scaler),\n",
"    ]\n",
")\n",
"\n",
"# Categorical columns: fill gaps with the literal 'unknown', then one-hot encode\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
"    [\n",
"        (\"imputer\", cat_imputer),\n",
"        (\"encoder\", cat_encoder),\n",
"    ]\n",
")\n",
"\n",
"# Route each column group through its own preprocessing branch\n",
"features_preprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"preprocessing_num\", preprocessing_num, num_columns),\n",
"        (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
"    ],\n",
"    remainder=\"passthrough\"\n",
")\n",
"\n",
"# Final preprocessing pipeline (the model is fitted separately below)\n",
"pipeline_end = Pipeline(\n",
"    [\n",
"        (\"features_preprocessing\", features_preprocessing),\n",
"    ]\n",
")\n",
"\n",
"# Hold out 20% of the rows for testing\n",
"X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=RANDOM_STATE)\n",
"\n",
"# Fit the preprocessing on the training part only, then apply it to both parts\n",
"X_train_processed = pipeline_end.fit_transform(X_train)\n",
"X_test_processed = pipeline_end.transform(X_test)\n",
"\n",
"# Seeded random forest: without a seed every run of the search can pick different parameters\n",
"model = RandomForestRegressor(random_state=RANDOM_STATE)\n",
"\n",
"# Hyperparameter grid: 3 * 4 * 3 = 36 candidates\n",
"param_grid = {\n",
"    'n_estimators': [50, 100, 200],  # number of trees\n",
"    'max_depth': [None, 10, 20, 30],  # maximum tree depth\n",
"    'min_samples_split': [2, 5, 10]  # minimum samples required to split a node\n",
"}\n",
"\n",
"# 3-fold cross-validated grid search scored by negated MSE\n",
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
"                           scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Run the search on the training data\n",
"grid_search.fit(X_train_processed, y_train)\n",
"\n",
"# Report the winning parameters; the score is a negated MSE, so flip its sign\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
"Обучение модели с новыми гиперпараметрами и сравнение её результатов с результатами модели со старыми гиперпараметрами"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 59,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" summary city state \\\n",
|
|||
|
"0 Viewed some red lights in the sky appearing to... Visalia CA \n",
|
|||
|
"1 Look like 1 or 3 crafts from North traveling s... Cincinnati OH \n",
|
|||
|
"3 One red light moving switly west to east, beco... Knoxville TN \n",
|
|||
|
"5 I'm familiar with all the fakery and UFO sight... Fullerton CA \n",
|
|||
|
"6 I was driving up lakes mead towards the lake a... Las Vegas NV \n",
|
|||
|
"\n",
|
|||
|
" date_time shape duration \\\n",
|
|||
|
"0 2021-12-15T21:45:00 light 2 minutes \n",
|
|||
|
"1 2021-12-16T09:45:00 triangle 14 seconds \n",
|
|||
|
"3 2021-12-10T19:30:00 triangle 20-30 seconds \n",
|
|||
|
"5 2020-07-07T23:00:00 unknown 2 minutes \n",
|
|||
|
"6 2020-04-23T03:00:00 oval 10 minutes \n",
|
|||
|
"\n",
|
|||
|
" stats \\\n",
|
|||
|
"0 Occurred : 12/15/2021 21:45 (Entered as : 12/... \n",
|
|||
|
"1 Occurred : 12/16/2021 09:45 (Entered as : 12/... \n",
|
|||
|
"3 Occurred : 12/10/2021 19:30 (Entered as : 12/... \n",
|
|||
|
"5 Occurred : 7/7/2020 23:00 (Entered as : 07/07... \n",
|
|||
|
"6 Occurred : 4/23/2020 03:00 (Entered as : 4/23... \n",
|
|||
|
"\n",
|
|||
|
" report_link \\\n",
|
|||
|
"0 http://www.nuforc.org/webreports/165/S165881.html \n",
|
|||
|
"1 http://www.nuforc.org/webreports/165/S165888.html \n",
|
|||
|
"3 http://www.nuforc.org/webreports/165/S165825.html \n",
|
|||
|
"5 http://www.nuforc.org/webreports/157/S157444.html \n",
|
|||
|
"6 http://www.nuforc.org/webreports/155/S155608.html \n",
|
|||
|
"\n",
|
|||
|
" text posted \\\n",
|
|||
|
"0 Viewed some red lights in the sky appearing to... 2021-12-19T00:00:00 \n",
|
|||
|
"1 Look like 1 or 3 crafts from North traveling s... 2021-12-19T00:00:00 \n",
|
|||
|
"3 One red light moving switly west to east, beco... 2021-12-19T00:00:00 \n",
|
|||
|
"5 I'm familiar with all the fakery and UFO sight... 2020-07-09T00:00:00 \n",
|
|||
|
"6 I was driving up lakes mead towards the lake a... 2020-05-01T00:00:00 \n",
|
|||
|
"\n",
|
|||
|
" city_latitude city_longitude \n",
|
|||
|
"0 36.356650 -119.347937 \n",
|
|||
|
"1 39.174503 -84.481363 \n",
|
|||
|
"3 35.961561 -83.980115 \n",
|
|||
|
"5 33.877422 -117.924978 \n",
|
|||
|
"6 36.141246 -115.186592 \n",
|
|||
|
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\tumvu\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1, 2, 3, 4] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Старые параметры: {'max_depth': 10, 'min_samples_split': 2, 'n_estimators': 50}\n",
|
|||
|
"Лучший результат (MSE) на старых параметрах: 0.6044726602932151\n",
|
|||
|
"\n",
|
|||
|
"Новые параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 200}\n",
|
|||
|
"Лучший результат (MSE) на новых параметрах: 4.113148481479761\n",
|
|||
|
"Среднеквадратическая ошибка (MSE) на тестовых данных: 0.14708677585880306\n",
|
|||
|
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 0.38351893807060305\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAHWCAYAAACi1sL/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1gUxxvA8e/Rjy6igAqigopiw94iVrD9jCUae9ck9kRjYuwaMbFrotGY2GPvGruiogiKgr2hiAULvdfb3x8nF06KoMChzud57oHbnZ19b28P9r2ZnZFJkiQhCIIgCIIgCIIgAKCl6QAEQRAEQRAEQRCKEpEkCYIgCIIgCIIgZCCSJEEQBEEQBEEQhAxEkiQIgiAIgiAIgpCBSJIEQRAEQRAEQRAyEEmSIAiCIAiCIAhCBiJJEgRBEARBEARByEAkSYIgCIIgCIIgCBmIJEkQBEEQBEEQBCEDkSQJgiAIgiAI2Tpw4AD+/v6q53v27OHGjRuaC0gQCoFIkgShCAsMDGT48OGUL18eAwMDTE1Nady4MUuWLCEhIUHT4QmCIAifgGvXrjFmzBju3bvHhQsX+Oqrr4iJidF0WIJQoGSSJEmaDkIQhMwOHjzIF198gb6+Pv369cPZ2Znk5GS8vLzYuXMnAwYMYNWqVZoOUxAEQfjIvXr1ikaNGnH//n0AunTpws6dOzUclSAULJEkCUIR9PDhQ6pXr06ZMmU4efIkNjY2auvv37/PwYMHGTNmjIYiFARBED4lSUlJXL9+HUNDQ5ycnDQdjiAUONHdThCKoF9//ZXY2Fj++uuvTAkSgIODg1qCJJPJGDlyJJs2baJSpUoYGBhQu3Ztzpw5o7bdo0eP+Oabb6hUqRJyuZzixYvzxRdfEBQUpFZu7dq1yGQy1cPQ0JBq1aqxevVqtXIDBgzA2Ng4U3w7duxAJpPh6empttzHxwd3d3fMzMwwNDSkWbNmnDt3Tq3M9OnTkclkhIaGqi2/dOkSMpmMtWvXqu3f3t5erdzjx4+Ry+XIZLJMr+vQoUM0bdoUIyMjTExMaN++fa761acfjzNnzjB8+HCKFy+Oqakp/fr1IyIiIlP53Ozn6tWrDBgwQNWV0tramkGDBhEWFpZlDPb29mrvSfoj4zG2t7enQ4cOOb6WoKAgZDIZ8+fPz7TO2dkZV1dX1XNPT09kMhk7duzItr4334Np06ahpaXFiRMn1MoNGzYMPT09AgICcoxPJpMxffp0tWXz5s1DJpOpxZbT9tk9MsaZ8TgsWrSIsmXLIpfLadasGdevX89U7+3bt+nWrRsWFhYYGBhQp04d9u3bl2UMAwYMyHL/AwYMyFT20KFDNGvWDBMTE0xNTalbty7//POPar2rq2um1/3zzz+jpaWlVu7s2bN88cUX2NnZoa+vj62tLePGjcvULXf69OlUqVIFY2NjTE1NadCgAXv27FErk9u68vL5d3V1xdnZOVPZ+fPnZ/qsvu08Tj8v0+u/desWcrmcfv36qZXz8vJCW1ubiRMnZlsX5O6Y5CX+vXv30r59e0qVKoW+vj4VKlRg1qxZpKWlqW2b1bme/rfmXf525fX9ePO8unjxoupczSpOfX19ateujZOTU54+k4LwodLRdACCIGS2f/9+ypcvT6NGjXK9zenTp9m6dSujR49GX1+f5cuX4+7ujq+vr+qf+8WLFzl//jxffvklZcqUISgoiBUrVuDq6srNmzcxNDRUq3PRokVYWloSHR3N33//zdChQ7G3t6dVq1Z5fk0nT56kbdu21K5dW3UhvWbNGlq0aMHZs2epV69enuvMytSpU0lMTMy0fMOGDfTv3x83Nzd++eUX4uPjWbFiBU2aNOHKlSuZkq2sjBw5EnNzc6ZPn86dO3dYsWIFjx49Ul205WU/x44d48GDBwwcOBBra2tu3LjBqlWruHHjBhcuXMh0oQLQtGlThg0bBigvDOfMmfPuB6qATJ48mf379zN48G
CuXbuGiYkJR44c4c8//2TWrFnUqFEjT/VFRkbi4eGRp21at26d6YJ5wYIFWSa069evJyYmhhEjRpCYmMiSJUto0aIF165dw8rKCoAbN27QuHFjSpcuzQ8//ICRkRHbtm3j888/Z+fOnXTu3DlTvfr6+mpfKgwZMiRTmbVr1zJo0CCqVq3Kjz/+iLm5OVeuXOHw4cP06tUry9e2Zs0aJk+ezIIFC9TKbN++nfj4eL7++muKFy+Or68vy5Yt48mTJ2zfvl1VLi4ujs6dO2Nvb09CQgJr166la9eueHt7qz6Dua2rqHBycmLWrFlMmDCBbt268b///Y+4uDgGDBhA5cqVmTlzZo7b5+aY5MXatWsxNjbm22+/xdjYmJMnTzJ16lSio6OZN29enuvLj79dufG2ZDLdu3wmBeGDJAmCUKRERUVJgNSpU6dcbwNIgHTp0iXVskePHkkGBgZS586dVcvi4+Mzbevt7S0B0vr161XL1qxZIwHSw4cPVcvu3r0rAdKvv/6qWta/f3/JyMgoU53bt2+XAOnUqVOSJEmSQqGQHB0dJTc3N0mhUKjFU65cOal169aqZdOmTZMA6dWrV2p1Xrx4UQKkNWvWqO2/bNmyqufXr1+XtLS0pLZt26rFHxMTI5mbm0tDhw5Vq/P58+eSmZlZpuVvSj8etWvXlpKTk1XLf/31VwmQ9u7dm+f9ZPVebN68WQKkM2fOZFpXunRpaeDAgarnp06dUjvGkiRJZcuWldq3b5/ja3n48KEESPPmzcu0rmrVqlKzZs0y7WP79u3Z1vfmeyBJknTt2jVJT09PGjJkiBQRESGVLl1aqlOnjpSSkpJjbJKkPJenTZumev79999LJUuWlGrXrq0WW07bjxgxItPy9u3bq8WZfhzkcrn05MkT1XIfHx8JkMaNG6da1rJlS6latWpSYmKiaplCoZAaNWokOTo6ZtpXr169JGNjY7VlRkZGUv/+/VXPIyMjJRMTE6l+/fpSQkKCWtmMn5FmzZqpXvfBgwclHR0d6bvvvsu0z6zOJw8PD0kmk0mPHj3KtC7dy5cvJUCaP39+nuvK7ec//XVUrVo1U9l58+Zl+lvztvM4q3M/LS1NatKkiWRlZSWFhoZKI0aMkHR0dKSLFy9mW092sjomeYk/q+M3fPhwydDQUO0ckslk0tSpU9XKvfm3Ny9/U/L6fmT8PP37778SILm7u0tvXhq+72dSED5UorudIBQx0dHRAJiYmORpu4YNG1K7dm3Vczs7Ozp16sSRI0dU3TzkcrlqfUpKCmFhYTg4OGBubs7ly5cz1RkREUFoaCgPHjxg0aJFaGtr06xZs0zlQkND1R5vjnrk7+/PvXv36NWrF2FhYapycXFxtGzZkjNnzqBQKNS2CQ8PV6szKirqrcfgxx9/xMXFhS+++EJt+bFjx4iMjKRnz55qdWpra1O/fn1OnTr11rpB2WVMV1dX9fzrr79GR0eHf//9N8/7yfheJCYmEhoaSoMGDQCyfC+Sk5PR19d/a4wpKSmEhoYSFhZGampqtuXi4+MzvW9vdgdKFxMTQ2hoKJGRkW/dPyi77c2YMYPVq1fj5uZGaGgo69atQ0cnb50Xnj59yrJly5gyZUqW3Yjyw+eff07p0qVVz+vVq0f9+vVV72l4eDgnT56ke/fuquOQfnzd3Ny4d+8eT58+VaszMTERAwODHPd77NgxYmJi+OGHHzKVzaoV0dfXl+7du9O1a9csWyMynk9xcXGEhobSqFEjJEniypUramXTz5HAwEDmzp2LlpYWjRs3fqe64O2f/3RpaWmZysbHx2dZNrfncTotLS3Wrl1LbGwsbdu2Zfny5fz444/UqVPnrdtm3F92xyQv8Wc8funnTNOmTYmPj+f27duqdSVLluTJkyc5xvUuf7ty+36kkySJH3/8ka5du1K/fv0cyxbGZ1IQigrR3U4QihhTU1OAPA+v6ujomGlZxYoViY+P59WrV1hbW5OQkICHhwdr1q
zh6dOnSBnGbckqCXFxcVH9rq+vz2+//Zap+0lcXBwlSpTIMbZ79+4B0L9//2zLREVFUaxYMdXzSpUq5Vjnm7y8vNi
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x500 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Compare a random forest tuned over the full grid ('old' parameters) with a fixed,\n",
"# hand-picked configuration ('new' parameters) for predicting 'city_latitude'.\n",
"\n",
"# Shared seed: the split and every forest below are reproducible\n",
"RANDOM_STATE = 42\n",
"# Same fold count for both searches, so their cross-validated MSE values are comparable\n",
"CV_FOLDS = 3\n",
"\n",
"# Load the first 100 reports and keep only complete rows\n",
"df = pd.read_csv(\"nuforc_reports.csv\").head(100).dropna()\n",
"\n",
"# Print the first rows to check the structure\n",
"print(df.head())\n",
"\n",
"# Target variable\n",
"target = df['city_latitude']\n",
"\n",
"# Features: drop the target itself (a model that sees it among the inputs is handed the answer)\n",
"# together with the columns that are not used\n",
"features = df.drop(columns=['city_latitude', 'summary', 'stats', 'report_link', 'posted', \"duration\"])\n",
"\n",
"# Split the remaining columns by dtype\n",
"num_columns = features.select_dtypes(include=['number']).columns\n",
"cat_columns = features.select_dtypes(include=['object']).columns\n",
"\n",
"# Numeric columns: fill gaps with the median, then standardise\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline([\n",
"    (\"imputer\", num_imputer),\n",
"    (\"scaler\", num_scaler),\n",
"])\n",
"\n",
"# Categorical columns: fill gaps with the literal 'unknown', then one-hot encode\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline([\n",
"    (\"imputer\", cat_imputer),\n",
"    (\"encoder\", cat_encoder),\n",
"])\n",
"\n",
"# Route each column group through its own preprocessing branch\n",
"features_preprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"preprocessing_num\", preprocessing_num, num_columns),\n",
"        (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
"    ],\n",
"    remainder=\"passthrough\"\n",
")\n",
"\n",
"# Final preprocessing pipeline (the models are fitted separately below)\n",
"pipeline_end = Pipeline([\n",
"    (\"features_preprocessing\", features_preprocessing),\n",
"])\n",
"\n",
"# Hold out 20% of the rows for testing\n",
"X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=RANDOM_STATE)\n",
"\n",
"# Fit the preprocessing on the training part only, then apply it to both parts\n",
"X_train_processed = pipeline_end.fit_transform(X_train)\n",
"X_test_processed = pipeline_end.transform(X_test)\n",
"\n",
"# 1. 'Old' setup: search the full hyperparameter grid\n",
"old_param_grid = {\n",
"    'n_estimators': [50, 100, 200],\n",
"    'max_depth': [None, 10, 20, 30],\n",
"    'min_samples_split': [2, 5, 10]\n",
"}\n",
"\n",
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(random_state=RANDOM_STATE),\n",
"                               param_grid=old_param_grid,\n",
"                               scoring='neg_mean_squared_error', cv=CV_FOLDS, n_jobs=-1, verbose=2)\n",
"\n",
"# Run the search on the training data\n",
"old_grid_search.fit(X_train_processed, y_train)\n",
"\n",
"# Best parameters and their cross-validated MSE (the score is negated, so flip its sign)\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_\n",
"\n",
"# 2. 'New' setup: a single fixed configuration, cross-validated the same way\n",
"new_param_grid = {\n",
"    'n_estimators': [200],\n",
"    'max_depth': [10],\n",
"    'min_samples_split': [10]\n",
"}\n",
"\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(random_state=RANDOM_STATE),\n",
"                               param_grid=new_param_grid,\n",
"                               scoring='neg_mean_squared_error', cv=CV_FOLDS)\n",
"\n",
"# Run the search on the training data\n",
"new_grid_search.fit(X_train_processed, y_train)\n",
"\n",
"# Parameters and cross-validated MSE of the 'new' configuration\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_\n",
"\n",
"# 3. Refit a forest with the 'new' parameters on the whole training part\n",
"model_best = RandomForestRegressor(**new_best_params, random_state=RANDOM_STATE)\n",
"model_best.fit(X_train_processed, y_train)\n",
"\n",
"# Predict on the held-out test part\n",
"y_pred = model_best.predict(X_test_processed)\n",
"\n",
"# Test-set error of the 'new' model\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"# Report the results\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nНовые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)\n",
"\n",
"# 4. Refit a forest with the 'old' best parameters for the visual comparison\n",
"model_old = RandomForestRegressor(**old_best_params, random_state=RANDOM_STATE)\n",
"model_old.fit(X_train_processed, y_train)\n",
"\n",
"# Predictions of the 'old' model on the test part\n",
"y_pred_old = model_old.predict(X_test_processed)\n",
"\n",
"# 5. Plot the true values against both sets of predictions\n",
"plt.figure(figsize=(10, 5))\n",
"plt.plot(y_test.values, label='Реальные значения', marker='o', linestyle='-', color='black')\n",
"plt.plot(y_pred_old, label='Предсказанные значения (старые параметры)', marker='x', linestyle='--', color='blue')\n",
"plt.plot(y_pred, label='Предсказанные значения (новые параметры)', marker='s', linestyle='--', color='orange')\n",
"plt.xlabel('Объекты')\n",
"# The plotted quantity is the latitude, so label the axis accordingly\n",
"plt.ylabel('Широта (city_latitude)')\n",
"plt.title('Сравнение реальных и предсказанных значений')\n",
"plt.legend()\n",
"plt.show()\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.12.0"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|