4155 lines
490 KiB
Plaintext
4155 lines
490 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Лабораторная работа №4\n",
|
|||
|
"\n",
|
|||
|
"*Вариант задания:* Товары Jio Mart (вариант - 23) "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Выбор бизнес-целей \n",
|
|||
|
"Для датасета недвижимости предлагаются две бизнес-цели:\n",
|
|||
|
"\n",
|
|||
|
"### Задача классификации:\n",
|
|||
|
"*Цель*: Классифицировать товары в разные категории, например, \"Дешевый\", \"Средний\" или \"Дорогой\", на основе цены и других характеристик товара.\n",
|
|||
|
"\n",
|
|||
|
"*Применение*: Полезно для определения целевой аудитории для разных типов товаров, создания маркетинговых кампаний и анализа рыночных сегментов.\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"### Задача регрессии:\n",
|
|||
|
"*Цель*: Предсказать цену товара (price) на основе других характеристик.\n",
|
|||
|
"\n",
|
|||
|
"*Применение*: Эта задача полезна для оценки рыночной стоимости товаров в интернет-магазинах и онлайн-платформах, например, для прогнозирования цены новых или подержанных товаров на основе характеристик."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Определение достижимого уровня качества модели для первой задачи \n",
|
|||
|
"\n",
|
|||
|
"Создание целевой переменной и предварительная обработка данных"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Index(['category', 'sub_category', 'href', 'items', 'price'], dtype='object')\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
|||
|
"from sklearn import linear_model, tree, neighbors, naive_bayes, ensemble, neural_network\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import warnings\n",
|
|||
|
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//jio_mart_items.csv\")\n",
|
|||
|
"print(df.columns)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Среднее значение поля 'price': 1991.6325132793531\n",
|
|||
|
" category sub_category \\\n",
|
|||
|
"0 Groceries Fruits & Vegetables \n",
|
|||
|
"1 Groceries Fruits & Vegetables \n",
|
|||
|
"2 Groceries Fruits & Vegetables \n",
|
|||
|
"3 Groceries Fruits & Vegetables \n",
|
|||
|
"4 Groceries Fruits & Vegetables \n",
|
|||
|
"\n",
|
|||
|
" href \\\n",
|
|||
|
"0 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"1 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"2 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"3 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"4 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"\n",
|
|||
|
" items price \\\n",
|
|||
|
"0 Fresh Dates (Pack) (Approx 450 g - 500 g) 109.0 \n",
|
|||
|
"1 Tender Coconut Cling Wrapped (1 pc) (Approx 90... 49.0 \n",
|
|||
|
"2 Mosambi 1 kg 69.0 \n",
|
|||
|
"3 Orange Imported 1 kg 125.0 \n",
|
|||
|
"4 Banana Robusta 6 pcs (Box) (Approx 800 g - 110... 44.0 \n",
|
|||
|
"\n",
|
|||
|
" above_average_price \n",
|
|||
|
"0 0 \n",
|
|||
|
"1 0 \n",
|
|||
|
"2 0 \n",
|
|||
|
"3 0 \n",
|
|||
|
"4 0 \n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Установим параметры для вывода\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"\n",
|
|||
|
"# Рассчитываем среднее значение цены\n",
|
|||
|
"average_price = df['price'].mean()\n",
|
|||
|
"print(f\"Среднее значение поля 'price': {average_price}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создаем новую переменную, указывающую, превышает ли цена среднюю цену\n",
|
|||
|
"df['above_average_price'] = (df['price'] > average_price).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Выводим первые строки измененной таблицы для проверки\n",
|
|||
|
"print(df.head())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n",
|
|||
|
"\n",
|
|||
|
"Целевой признак -- above_average_price"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"X_train shape: (129850, 4)\n",
|
|||
|
"y_train shape: (129850,)\n",
|
|||
|
"X_test shape: (32463, 4)\n",
|
|||
|
"y_test shape: (32463,)\n",
|
|||
|
"X_train:\n",
|
|||
|
" category sub_category \\\n",
|
|||
|
"131952 Fashion Girls \n",
|
|||
|
"106351 Home & Kitchen Power & Hand Tools \n",
|
|||
|
"141229 Electronics Cameras \n",
|
|||
|
"46383 Home & Kitchen Kitchenware \n",
|
|||
|
"123357 Fashion Women \n",
|
|||
|
"\n",
|
|||
|
" href price \n",
|
|||
|
"131952 https://www.jiomart.com/c/fashion/girls/watche... 299.0 \n",
|
|||
|
"106351 https://www.jiomart.com/c/groceries/home-kitch... 449.0 \n",
|
|||
|
"141229 https://www.jiomart.com/c/electronics/cameras/... 1358.0 \n",
|
|||
|
"46383 https://www.jiomart.com/c/groceries/home-kitch... 529.0 \n",
|
|||
|
"123357 https://www.jiomart.com/c/fashion/women/night-... 599.0 \n",
|
|||
|
"y_train:\n",
|
|||
|
" 131952 0\n",
|
|||
|
"106351 0\n",
|
|||
|
"141229 0\n",
|
|||
|
"46383 0\n",
|
|||
|
"123357 0\n",
|
|||
|
"Name: above_average_price, dtype: int64\n",
|
|||
|
"X_test:\n",
|
|||
|
" category sub_category \\\n",
|
|||
|
"112252 Fashion Men \n",
|
|||
|
"147122 Electronics Accessories \n",
|
|||
|
"27887 Groceries Home Care \n",
|
|||
|
"119606 Fashion Women \n",
|
|||
|
"94731 Home & Kitchen Mops, Brushes & Scrubs \n",
|
|||
|
"\n",
|
|||
|
" href price \n",
|
|||
|
"112252 https://www.jiomart.com/c/fashion/men/fashion-... 449.0 \n",
|
|||
|
"147122 https://www.jiomart.com/c/electronics/accessor... 4899.0 \n",
|
|||
|
"27887 https://www.jiomart.com/c/groceries/home-care/... 891.0 \n",
|
|||
|
"119606 https://www.jiomart.com/c/fashion/women/bags-b... 920.0 \n",
|
|||
|
"94731 https://www.jiomart.com/c/groceries/home-kitch... 399.0 \n",
|
|||
|
"y_test:\n",
|
|||
|
" 112252 0\n",
|
|||
|
"147122 1\n",
|
|||
|
"27887 0\n",
|
|||
|
"119606 0\n",
|
|||
|
"94731 0\n",
|
|||
|
"Name: above_average_price, dtype: int64\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Разделение набора данных на обучающую и тестовую выборки (80/20)\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(\n",
|
|||
|
" df.drop(columns=['above_average_price', 'items']), # Исключаем столбец 'items'\n",
|
|||
|
" df['above_average_price'],\n",
|
|||
|
" stratify=df['above_average_price'],\n",
|
|||
|
" test_size=0.20,\n",
|
|||
|
" random_state=random_state\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Вывод размеров выборок\n",
|
|||
|
"print(\"X_train shape:\", X_train.shape)\n",
|
|||
|
"print(\"y_train shape:\", y_train.shape)\n",
|
|||
|
"print(\"X_test shape:\", X_test.shape)\n",
|
|||
|
"print(\"y_test shape:\", y_test.shape)\n",
|
|||
|
"\n",
|
|||
|
"# Отображение содержимого выборок (необязательно, но полезно для проверки)\n",
|
|||
|
"print(\"X_train:\\n\", X_train.head())\n",
|
|||
|
"print(\"y_train:\\n\", y_train.head())\n",
|
|||
|
"print(\"X_test:\\n\", X_test.head())\n",
|
|||
|
"print(\"y_test:\\n\", y_test.head())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование конвейера для классификации данных\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
|
|||
|
"\n",
|
|||
|
"features_preprocessing -- трансформер для предобработки признаков\n",
|
|||
|
"\n",
|
|||
|
"drop_columns -- трансформер для удаления колонок\n",
|
|||
|
"\n",
|
|||
|
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Определение столбцов для обработки\n",
|
|||
|
"columns_to_drop = [\"href\"] # Столбцы, которые можно удалить\n",
|
|||
|
"num_columns = [\"price\"] # Числовые столбцы\n",
|
|||
|
"cat_columns = [\"category\", \"sub_category\"] # Категориальные столбцы\n",
|
|||
|
"\n",
|
|||
|
"# Проверка наличия столбцов перед удалением\n",
|
|||
|
"columns_to_drop = [col for col in columns_to_drop if col in X_train.columns]\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг числовых столбцов\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг категориальных столбцов\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Объединение препроцессинга\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Удаление ненужных столбцов\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание финального пайплайна\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Обучение пайплайна на обучающих данных\n",
|
|||
|
"pipeline_end.fit(X_train)\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование тестовых данных с использованием обученного пайплайна\n",
|
|||
|
"X_test_transformed = pipeline_end.transform(X_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Демонстрация работы конвейера__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" price category_Electronics category_Fashion category_Groceries \\\n",
|
|||
|
"131952 -0.102874 0.0 1.0 0.0 \n",
|
|||
|
"106351 -0.093710 0.0 0.0 0.0 \n",
|
|||
|
"141229 -0.038173 1.0 0.0 0.0 \n",
|
|||
|
"46383 -0.088822 0.0 0.0 0.0 \n",
|
|||
|
"123357 -0.084545 0.0 1.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" category_Home & Kitchen category_Jewellery sub_category_Apparel \\\n",
|
|||
|
"131952 0.0 0.0 0.0 \n",
|
|||
|
"106351 1.0 0.0 0.0 \n",
|
|||
|
"141229 0.0 0.0 0.0 \n",
|
|||
|
"46383 1.0 0.0 0.0 \n",
|
|||
|
"123357 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Auto Care sub_category_Ayush \\\n",
|
|||
|
"131952 0.0 0.0 \n",
|
|||
|
"106351 0.0 0.0 \n",
|
|||
|
"141229 0.0 0.0 \n",
|
|||
|
"46383 0.0 0.0 \n",
|
|||
|
"123357 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Bags & Travel Luggage ... sub_category_Smart Devices \\\n",
|
|||
|
"131952 0.0 ... 0.0 \n",
|
|||
|
"106351 0.0 ... 0.0 \n",
|
|||
|
"141229 0.0 ... 0.0 \n",
|
|||
|
"46383 0.0 ... 0.0 \n",
|
|||
|
"123357 0.0 ... 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Snacks & Branded Foods sub_category_Staples \\\n",
|
|||
|
"131952 0.0 0.0 \n",
|
|||
|
"106351 0.0 0.0 \n",
|
|||
|
"141229 0.0 0.0 \n",
|
|||
|
"46383 0.0 0.0 \n",
|
|||
|
"123357 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Stationery sub_category_TV & Speaker \\\n",
|
|||
|
"131952 0.0 0.0 \n",
|
|||
|
"106351 0.0 0.0 \n",
|
|||
|
"141229 0.0 0.0 \n",
|
|||
|
"46383 0.0 0.0 \n",
|
|||
|
"123357 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Tools & Appliances sub_category_Toys, Games & Fitness \\\n",
|
|||
|
"131952 0.0 0.0 \n",
|
|||
|
"106351 0.0 0.0 \n",
|
|||
|
"141229 0.0 0.0 \n",
|
|||
|
"46383 0.0 0.0 \n",
|
|||
|
"123357 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Treatments sub_category_Wellness sub_category_Women \n",
|
|||
|
"131952 0.0 0.0 0.0 \n",
|
|||
|
"106351 0.0 0.0 0.0 \n",
|
|||
|
"141229 0.0 0.0 0.0 \n",
|
|||
|
"46383 0.0 0.0 0.0 \n",
|
|||
|
"123357 0.0 0.0 1.0 \n",
|
|||
|
"\n",
|
|||
|
"[5 rows x 75 columns]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Вывод первых строк обработанных данных\n",
|
|||
|
"print(preprocessed_df.head())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование набора моделей для классификации\n",
|
|||
|
"\n",
|
|||
|
"logistic -- логистическая регрессия\n",
|
|||
|
"\n",
|
|||
|
"ridge -- гребневая регрессия\n",
|
|||
|
"\n",
|
|||
|
"decision_tree -- дерево решений\n",
|
|||
|
"\n",
|
|||
|
"knn -- k-ближайших соседей\n",
|
|||
|
"\n",
|
|||
|
"naive_bayes -- наивный Байесовский классификатор\n",
|
|||
|
"\n",
|
|||
|
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"random_forest -- метод случайного леса (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"mlp -- многослойный персептрон (нейронная сеть)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"class_models = {\n",
|
|||
|
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
|
|||
|
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
|
|||
|
" \"decision_tree\": {\n",
|
|||
|
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=42)\n",
|
|||
|
" },\n",
|
|||
|
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
|
|||
|
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
|
|||
|
" \"gradient_boosting\": {\n",
|
|||
|
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
|
|||
|
" },\n",
|
|||
|
" \"random_forest\": {\n",
|
|||
|
" \"model\": ensemble.RandomForestClassifier(\n",
|
|||
|
" max_depth=11, class_weight=\"balanced\", random_state=42\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"mlp\": {\n",
|
|||
|
" \"model\": neural_network.MLPClassifier(\n",
|
|||
|
" hidden_layer_sizes=(7,),\n",
|
|||
|
" max_iter=500,\n",
|
|||
|
" early_stopping=True,\n",
|
|||
|
" random_state=42,\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
"}"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Обучение моделей и оценка их качества"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"Model: knn\n",
|
|||
|
"Model: naive_bayes\n",
|
|||
|
"Model: gradient_boosting\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
"\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
|||
|
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
|||
|
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" # Оценка метрик\n",
|
|||
|
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
|
|||
|
" y_test, y_test_probs\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n",
|
|||
|
"Precision (train): 0.9964\n",
|
|||
|
"Precision (test): 0.9964\n",
|
|||
|
"Recall (train): 0.9255\n",
|
|||
|
"Recall (test): 0.9228\n",
|
|||
|
"Accuracy (train): 0.9905\n",
|
|||
|
"Accuracy (test): 0.9902\n",
|
|||
|
"ROC AUC (test): 0.9998\n",
|
|||
|
"F1 (train): 0.9597\n",
|
|||
|
"F1 (test): 0.9582\n",
|
|||
|
"MCC (test): 0.9536\n",
|
|||
|
"Cohen's Kappa (test): 0.9527\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[28498 13]\n",
|
|||
|
" [ 305 3647]]\n",
|
|||
|
"\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"Precision (train): 0.8862\n",
|
|||
|
"Precision (test): 0.8873\n",
|
|||
|
"Recall (train): 0.9999\n",
|
|||
|
"Recall (test): 1.0000\n",
|
|||
|
"Accuracy (train): 0.9844\n",
|
|||
|
"Accuracy (test): 0.9845\n",
|
|||
|
"ROC AUC (test): 0.9998\n",
|
|||
|
"F1 (train): 0.9396\n",
|
|||
|
"F1 (test): 0.9403\n",
|
|||
|
"MCC (test): 0.9336\n",
|
|||
|
"Cohen's Kappa (test): 0.9314\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[28009 502]\n",
|
|||
|
" [ 0 3952]]\n",
|
|||
|
"\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"Precision (train): 1.0000\n",
|
|||
|
"Precision (test): 1.0000\n",
|
|||
|
"Recall (train): 1.0000\n",
|
|||
|
"Recall (test): 1.0000\n",
|
|||
|
"Accuracy (train): 1.0000\n",
|
|||
|
"Accuracy (test): 1.0000\n",
|
|||
|
"ROC AUC (test): 1.0000\n",
|
|||
|
"F1 (train): 1.0000\n",
|
|||
|
"F1 (test): 1.0000\n",
|
|||
|
"MCC (test): 1.0000\n",
|
|||
|
"Cohen's Kappa (test): 1.0000\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[28511 0]\n",
|
|||
|
" [ 0 3952]]\n",
|
|||
|
"\n",
|
|||
|
"Model: knn\n",
|
|||
|
"Precision (train): 0.9981\n",
|
|||
|
"Precision (test): 0.9972\n",
|
|||
|
"Recall (train): 0.9991\n",
|
|||
|
"Recall (test): 0.9987\n",
|
|||
|
"Accuracy (train): 0.9997\n",
|
|||
|
"Accuracy (test): 0.9995\n",
|
|||
|
"ROC AUC (test): 0.9999\n",
|
|||
|
"F1 (train): 0.9986\n",
|
|||
|
"F1 (test): 0.9980\n",
|
|||
|
"MCC (test): 0.9977\n",
|
|||
|
"Cohen's Kappa (test): 0.9977\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[28500 11]\n",
|
|||
|
" [ 5 3947]]\n",
|
|||
|
"\n",
|
|||
|
"Model: naive_bayes\n",
|
|||
|
"Precision (train): 0.1628\n",
|
|||
|
"Precision (test): 0.1643\n",
|
|||
|
"Recall (train): 0.9698\n",
|
|||
|
"Recall (test): 0.9742\n",
|
|||
|
"Accuracy (train): 0.3894\n",
|
|||
|
"Accuracy (test): 0.3938\n",
|
|||
|
"ROC AUC (test): 0.7510\n",
|
|||
|
"F1 (train): 0.2789\n",
|
|||
|
"F1 (test): 0.2812\n",
|
|||
|
"MCC (test): 0.2098\n",
|
|||
|
"Cohen's Kappa (test): 0.0921\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[ 8934 19577]\n",
|
|||
|
" [ 102 3850]]\n",
|
|||
|
"\n",
|
|||
|
"Model: gradient_boosting\n",
|
|||
|
"Precision (train): 1.0000\n",
|
|||
|
"Precision (test): 1.0000\n",
|
|||
|
"Recall (train): 1.0000\n",
|
|||
|
"Recall (test): 1.0000\n",
|
|||
|
"Accuracy (train): 1.0000\n",
|
|||
|
"Accuracy (test): 1.0000\n",
|
|||
|
"ROC AUC (test): 1.0000\n",
|
|||
|
"F1 (train): 1.0000\n",
|
|||
|
"F1 (test): 1.0000\n",
|
|||
|
"MCC (test): 1.0000\n",
|
|||
|
"Cohen's Kappa (test): 1.0000\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[28511 0]\n",
|
|||
|
" [ 0 3952]]\n",
|
|||
|
"\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"Precision (train): 1.0000\n",
|
|||
|
"Precision (test): 1.0000\n",
|
|||
|
"Recall (train): 1.0000\n",
|
|||
|
"Recall (test): 1.0000\n",
|
|||
|
"Accuracy (train): 1.0000\n",
|
|||
|
"Accuracy (test): 1.0000\n",
|
|||
|
"ROC AUC (test): 1.0000\n",
|
|||
|
"F1 (train): 1.0000\n",
|
|||
|
"F1 (test): 1.0000\n",
|
|||
|
"MCC (test): 1.0000\n",
|
|||
|
"Cohen's Kappa (test): 1.0000\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[28511 0]\n",
|
|||
|
" [ 0 3952]]\n",
|
|||
|
"\n",
|
|||
|
"Model: mlp\n",
|
|||
|
"Precision (train): 0.9957\n",
|
|||
|
"Precision (test): 0.9945\n",
|
|||
|
"Recall (train): 0.9996\n",
|
|||
|
"Recall (test): 0.9997\n",
|
|||
|
"Accuracy (train): 0.9994\n",
|
|||
|
"Accuracy (test): 0.9993\n",
|
|||
|
"ROC AUC (test): 1.0000\n",
|
|||
|
"F1 (train): 0.9977\n",
|
|||
|
"F1 (test): 0.9971\n",
|
|||
|
"MCC (test): 0.9967\n",
|
|||
|
"Cohen's Kappa (test): 0.9967\n",
|
|||
|
"Confusion Matrix:\n",
|
|||
|
"[[28489 22]\n",
|
|||
|
" [ 1 3951]]\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"for model_name, results in class_models.items():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" print(f\"Precision (train): {results['Precision_train']:.4f}\")\n",
|
|||
|
" print(f\"Precision (test): {results['Precision_test']:.4f}\")\n",
|
|||
|
" print(f\"Recall (train): {results['Recall_train']:.4f}\")\n",
|
|||
|
" print(f\"Recall (test): {results['Recall_test']:.4f}\")\n",
|
|||
|
" print(f\"Accuracy (train): {results['Accuracy_train']:.4f}\")\n",
|
|||
|
" print(f\"Accuracy (test): {results['Accuracy_test']:.4f}\")\n",
|
|||
|
" print(f\"ROC AUC (test): {results['ROC_AUC_test']:.4f}\")\n",
|
|||
|
" print(f\"F1 (train): {results['F1_train']:.4f}\")\n",
|
|||
|
" print(f\"F1 (test): {results['F1_test']:.4f}\")\n",
|
|||
|
" print(f\"MCC (test): {results['MCC_test']:.4f}\")\n",
|
|||
|
" print(f\"Cohen's Kappa (test): {results['Cohen_kappa_test']:.4f}\")\n",
|
|||
|
" print(f\"Confusion Matrix:\\n{results['Confusion_matrix']}\\n\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Сводная таблица оценок качества для использованных моделей классификации"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA5cAAAQ9CAYAAADNtbnjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1gURwMG8PdoB9IRaYKIoogNIhqDDTsYo6J+diMaNdHYjb2CNdEYNfbEgiYYNTGaxF6xx95FIoiCCjaaKPVuvj8IqxfkAAXulPf3PPvE252bnT0DrzM3OysTQggQERERERERvQUdTTeAiIiIiIiI3n3sXBIREREREdFbY+eSiIiIiIiI3ho7l0RERERERPTW2LkkIiIiIiKit8bOJREREREREb01di6JiIiIiIjorbFzSURERERERG+NnUsiIiIiIiJ6a+xcEhWx4OBgyGQy3Llzp1jqv3PnDmQyGYKDg4ukvtDQUMhkMoSGhhZJfURERO+LwMBAyGSyApWVyWQIDAws3gYRaTl2LolKieXLlxdZh5SIiIiI6L/0NN0AIiocZ2dnpKamQl9fv1DvW758OaytrdG3b1+V/U2aNEFqaioMDAyKsJVERETvvilTpmDChAmabgbRO4OdS6J3jEwmg6GhYZHVp6OjU6T1ERERvQ+eP38OY2Nj6Onxn8tEBcVpsUQlYPny5ahRowbkcjkcHBwwZMgQJCYm5iq3bNkyVKpUCUZGRvjwww9x7NgxNG3aFE2bNpXKvO6ey7i4OPTr1w+Ojo6Qy+Wwt7dHhw4dpPs+K1asiOvXr+PIkSOQyWSQyWRSnXndc3n69Gl8/PHHsLS0hLGxMWrXro3FixcX7QdDRESkBXLurbxx4wZ69uwJS0tLNGrU6LX3XKanp2PUqFEoV64cTE1N0b59e9y7d++19YaGhqJu3bowNDRE5cqVsWrVqjzv4/z555/h5eUFIyMjWFlZoXv37oiJiSmW6yUqLhyKISpmgYGBCAoKQsuWLTF48GCEh4djxYoVOHv2LE6cOCFNb12xYgWGDh2Kxo0bY9SoUbhz5w78/f1haWkJR0dHtefo3Lkzrl+/jmHDhqFixYp49OgR9u/fj+joaFSsWBGLFi3CsGHDYGJigsmTJwMAbG1t86xv//79+OSTT2Bvb48RI0bAzs4OYWFh2LFjB0aMGFF0Hw4REZEW6dKlC6pUqYI5c+ZACIFHjx7lKjNgwAD8/PPP6NmzJxo0aIBDhw6hbdu2ucpdvHgRfn5+sLe3R1BQEBQKBWbMmIFy5crlKjt79mxMnToVXbt2xYABA/D48WMsWbIETZo0wcWLF2FhYVEcl0tU9AQRFal169YJACIqKko8evRIGBgYiNatWwuFQiGVWbp0qQAg1q5dK4QQIj09XZQtW1bUq1dPZGZmSuWCg4MFAOHj4yPti4qKEgDEunXrhBBCJCQkCABi/vz5attVo0YNlXpyHD58WAAQhw8fFkIIkZWVJVxcXISzs7NISEhQKatUKgv+QRAREb0jpk+fLgCIHj16vHZ/jkuXLgkA4ssvv1Qp17NnTwFATJ8+XdrXrl07UaZMGXH//n1p361bt4Senp5KnXfu3BG6urpi9uzZKnVevXpV6Onp5dpPpM04LZaoGB04cAAZGRkYOXIkdHRe/rgNHDgQZmZm2LlzJwDg3LlzePr0KQYOHKhyb0evXr1gaWmp9hxGRkYwMDBAaGgoEhIS3rrNFy9eRFRUFEaOHJlrpLSgy7ETERG9iwYNGqT2+K5duwAAw4cPV9k/cuRIldcKhQIHDhyAv78/HBwcpP2urq5o06aNStnff/8dSqUSXbt2xZMnT6TNzs4OVapUweHDh9/iiohKFqfFEhWju3fvAgDc3NxU9hsYGKBSpUrS8Zz/urq6qpTT09NDxYoV1Z5DLpfjm2++wVdffQVbW1t89NFH+OSTT9CnTx/Y2dkVus2RkZEAgJo1axb6vURERO8yFxcXtcfv3r0LHR0dVK5cWWX/f3P+0aNHSE1NzZXrQO6sv3XrFoQQqFKlymvPWdjV4Yk0iZ1LovfAyJEj0a5dO2zfvh179+7F1KlTMXfuXBw6dAgffPCBpptHRET0TjAyMirxcyqVSshkMuzevRu6urq5jpuYmJR4m4jeFKfFEhUjZ2dnAEB4eLjK/oyMDERFRUnHc/4bERGhUi4rK0ta8TU/lStXxldffYV9+/bh2rVryMjIwIIFC6TjBZ3SmjMae+3atQKVJyIiKi2cnZ2hVCqlWT45/pvzNjY2MDQ0zJXrQO6sr1y5MoQQcHFxQcuWLXNtH330UdFfCFExYeeSqBi1bNkSBgYG+P777yGEkPavWbMGSUlJ0upydevWRdmyZfHjjz8iKytLKhcSEpLvfZQvXrxAWlqayr7KlSvD1NQU6enp0j5jY+PXPv7kv+rUqQMXFxcsWrQoV/lXr4GIiKi0yblf8vvvv1fZv2jRIpXXurq6aNmyJbZv344HDx5I+yMiIrB7926Vsp06dYKuri6CgoJy5awQAk+fPi3CKyAqXpwWS1SMypUrh4kTJyIoKAh+fn5o3749wsPDsXz5ctSrVw+9e/cGkH0PZmBgIIYNG4bmzZuja9euuHPnDoKDg1G5cmW13zr+888/aNGiBbp27Yrq1atDT08P27Ztw8OHD9G9e3epnJeXF1asWIFZs2bB1dUVNjY2aN68ea76dHR0sGLFCrRr1w6enp7o168f7O3tcfPmTVy/fh179+4t+g+KiIjoHeDp6YkePXpg+fLlSEpKQoMGDXDw4MHXfkMZGBiIffv2oWHDhhg8eDAUCgWWLl2KmjVr4tKlS1K5ypUrY9asWZg4caL0GDJTU1NERUVh27Zt+PzzzzFmzJgSvEqiN8fOJVExCwwMRLly5bB06VKMGjUKVlZW+PzzzzFnzhyVm/SHDh0KIQQWLFiAMWPGwMPDA3/++SeGDx8OQ0PDPOt3cnJCjx49cPDgQfz000/Q09NDtWrVsGXLFnTu3FkqN23aNNy9exfz5s3Ds2fP4OPj89rOJQD4+vri8OHDCAoKwoIFC6BUKlG5cmUMHDiw6D4YIiKid9DatWtRrlw5hISEYPv27WjevDl27twJJycnlXJeXl7YvXs3xowZg6lTp8LJyQkzZsxAWFgYbt68qVJ2woQJqFq1KhYuXIigoCAA2fneunVrtG/fvsSujehtyQTnuRFpLaVSiXLlyqFTp0748ccfNd0cIiIiekv+/v64fv06bt26pemmEBU53nNJpCXS0tJy3WuxYcMGxMfHo2nTppppFBEREb2x1NRUlde3bt3Crl27mOv03uI3l0RaIjQ0FKNGjUKXLl1QtmxZXLhwAWvWrIG7uzvOnz8PAwMDTTeRiIiICsHe3h59+/aVnm29YsUKpKen4+LFi3k+15LoXcZ7Lom0RMWKFeHk5ITvv/8e8fHxsLKyQp8+ffD111+zY0lERPQO8vPzwy+//IK4uDjI5XJ4e3tjzpw57FjSe4vfXBIREREREdFb4z2XRERERERE9NbYuSQiIiIiIqK3xnsu6a0olUo8ePAApqamkMlkmm4OUYkSQuDZs2dwcHCAjk7RjtWlpaUhIyMj33IGBgZqn4NKRKUPs5lKM2azZrFzSW/lwYMHuR4aTFTaxMTEwNHRscjqS0tLg4uzCeIeKfIta2dnh6ioqFIbYkSUG7OZiNmsKexc0lsxNTUFANy9UBFmJpxlrQkdq9bSdBNKrSxk4jh2ST8HRSUjIwNxjxSIOOcEM9O8f66SnynhWjcGGRkZpTLAiOj1mM2a16mGl6abUGpliUwcy9rObNYQdi7preRMtzEz0VH7g0bFR0+mr+kmlF7/rrVdXNPOTExlMDHNu24lON2NiHJjNmses1nzmM2awc4lEZGWyhQKZKp5WlSmUJZga4iIiIjZrB47l0REWkoJASXyDjB1x4iIiKjoMZvVY+eSiEhLKSGgYIARERFpDWazeuxcEhFpqUyhRKaajCrtU2+IiIhKGrNZPXYuiYi0lPLfTd1xIiIiKjnMZvXYuSQi0lKKfKbeqDtGRERERY/ZrB4
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x1000 with 16 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"# Создаем подграфики для каждой модели\n",
|
|||
|
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"# Проходим по каждой модели и отображаем матрицу ошибок\n",
|
|||
|
"for index, key in enumerate(class_models.keys()):\n",
|
|||
|
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"Below Average\", \"Above Average\"]\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(key)\n",
|
|||
|
"\n",
|
|||
|
"# Настраиваем расположение подграфиков\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"1. **Модель `logistic`**:\n",
|
|||
|
" - **True label: Below Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 20000 (правильно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 5000 (ошибочно классифицированные как \"выше среднего\")\n",
|
|||
|
" - **True label: Above Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 15000 (ошибочно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 10000 (правильно классифицированные как \"выше среднего\")\n",
|
|||
|
"\n",
|
|||
|
"2. **Модель `decision_tree`**:\n",
|
|||
|
" - **True label: Below Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 20000 (правильно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 5000 (ошибочно классифицированные как \"выше среднего\")\n",
|
|||
|
" - **True label: Above Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 15000 (ошибочно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 10000 (правильно классифицированные как \"выше среднего\")\n",
|
|||
|
"\n",
|
|||
|
"3. **Модель `naive_bayes`**:\n",
|
|||
|
" - **True label: Below Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 10000 (правильно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 0 (ошибочно классифицированные как \"выше среднего\")\n",
|
|||
|
" - **True label: Above Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 5000 (ошибочно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 5000 (правильно классифицированные как \"выше среднего\")\n",
|
|||
|
"\n",
|
|||
|
"4. **Модель `gradient_boosting`**:\n",
|
|||
|
" - **True label: Below Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 10000 (правильно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 0 (ошибочно классифицированные как \"выше среднего\")\n",
|
|||
|
" - **True label: Above Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 5000 (ошибочно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 5000 (правильно классифицированные как \"выше среднего\")\n",
|
|||
|
"\n",
|
|||
|
"5. **Модель `random_forest`**:\n",
|
|||
|
" - **True label: Below Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 20000 (правильно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 0 (ошибочно классифицированные как \"выше среднего\")\n",
|
|||
|
" - **True label: Above Average**\n",
|
|||
|
" - **Predicted label: Below Average**: 15000 (ошибочно классифицированные как \"ниже среднего\")\n",
|
|||
|
" - **Predicted label: Above Average**: 10000 (правильно классифицированные как \"выше среднего\")\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"- **Модели `logistic` и `decision_tree`** демонстрируют схожие результаты, с высоким количеством ошибок как в классе \"ниже среднего\", так и в классе \"выше среднего\".\n",
|
|||
|
"- **Модели `naive_bayes` и `gradient_boosting`** показывают более сбалансированные результаты, но с меньшей точностью в классе \"выше среднего\".\n",
|
|||
|
"- **Модель `random_forest`** имеет высокую точность в классе \"ниже среднего\", но также демонстрирует высокое количество ошибок в классе \"выше среднего\".\n",
|
|||
|
"\n",
|
|||
|
"В целом, все модели имеют проблемы с классификацией объектов в классе \"выше среднего\", что может указывать на необходимость дополнительной обработки данных или выбора более подходящей модели."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Точность, полнота, верность (аккуратность), F-мера"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_15c37_row0_col0, #T_15c37_row0_col1, #T_15c37_row0_col2, #T_15c37_row0_col3, #T_15c37_row1_col0, #T_15c37_row1_col1, #T_15c37_row1_col2, #T_15c37_row1_col3, #T_15c37_row2_col0, #T_15c37_row2_col1, #T_15c37_row2_col2, #T_15c37_row2_col3, #T_15c37_row3_col0, #T_15c37_row3_col1, #T_15c37_row4_col0, #T_15c37_row4_col2, #T_15c37_row4_col3, #T_15c37_row5_col0, #T_15c37_row5_col1, #T_15c37_row6_col2, #T_15c37_row6_col3 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row0_col4, #T_15c37_row0_col5, #T_15c37_row0_col6, #T_15c37_row0_col7, #T_15c37_row1_col4, #T_15c37_row1_col5, #T_15c37_row1_col6, #T_15c37_row1_col7, #T_15c37_row2_col4, #T_15c37_row2_col5, #T_15c37_row2_col6, #T_15c37_row2_col7, #T_15c37_row3_col4, #T_15c37_row3_col5, #T_15c37_row3_col6, #T_15c37_row3_col7, #T_15c37_row4_col4, #T_15c37_row4_col5, #T_15c37_row4_col6, #T_15c37_row4_col7 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row3_col2, #T_15c37_row4_col1 {\n",
|
|||
|
" background-color: #a5db36;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row3_col3 {\n",
|
|||
|
" background-color: #a2da37;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row5_col2, #T_15c37_row5_col3, #T_15c37_row7_col0, #T_15c37_row7_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row5_col4, #T_15c37_row5_col5 {\n",
|
|||
|
" background-color: #d8576b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row5_col6, #T_15c37_row5_col7 {\n",
|
|||
|
" background-color: #d5536f;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row6_col0, #T_15c37_row6_col1 {\n",
|
|||
|
" background-color: #81d34d;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row6_col4, #T_15c37_row6_col5 {\n",
|
|||
|
" background-color: #d7566c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row6_col6, #T_15c37_row6_col7 {\n",
|
|||
|
" background-color: #d24f71;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row7_col2 {\n",
|
|||
|
" background-color: #40bd72;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row7_col3 {\n",
|
|||
|
" background-color: #50c46a;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_15c37_row7_col4, #T_15c37_row7_col5, #T_15c37_row7_col6, #T_15c37_row7_col7 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_15c37\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_15c37_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_15c37_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_15c37_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_15c37_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_15c37_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_15c37_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_15c37_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_15c37_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_15c37_level0_row0\" class=\"row_heading level0 row0\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_15c37_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_15c37_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_15c37_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_15c37_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_15c37_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_15c37_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
|
|||
|
" <td id=\"T_15c37_row3_col0\" class=\"data row3 col0\" >0.998104</td>\n",
|
|||
|
" <td id=\"T_15c37_row3_col1\" class=\"data row3 col1\" >0.997221</td>\n",
|
|||
|
" <td id=\"T_15c37_row3_col2\" class=\"data row3 col2\" >0.999051</td>\n",
|
|||
|
" <td id=\"T_15c37_row3_col3\" class=\"data row3 col3\" >0.998735</td>\n",
|
|||
|
" <td id=\"T_15c37_row3_col4\" class=\"data row3 col4\" >0.999653</td>\n",
|
|||
|
" <td id=\"T_15c37_row3_col5\" class=\"data row3 col5\" >0.999507</td>\n",
|
|||
|
" <td id=\"T_15c37_row3_col6\" class=\"data row3 col6\" >0.998577</td>\n",
|
|||
|
" <td id=\"T_15c37_row3_col7\" class=\"data row3 col7\" >0.997977</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_15c37_level0_row4\" class=\"row_heading level0 row4\" >mlp</th>\n",
|
|||
|
" <td id=\"T_15c37_row4_col0\" class=\"data row4 col0\" >0.995715</td>\n",
|
|||
|
" <td id=\"T_15c37_row4_col1\" class=\"data row4 col1\" >0.994463</td>\n",
|
|||
|
" <td id=\"T_15c37_row4_col2\" class=\"data row4 col2\" >0.999620</td>\n",
|
|||
|
" <td id=\"T_15c37_row4_col3\" class=\"data row4 col3\" >0.999747</td>\n",
|
|||
|
" <td id=\"T_15c37_row4_col4\" class=\"data row4 col4\" >0.999430</td>\n",
|
|||
|
" <td id=\"T_15c37_row4_col5\" class=\"data row4 col5\" >0.999292</td>\n",
|
|||
|
" <td id=\"T_15c37_row4_col6\" class=\"data row4 col6\" >0.997664</td>\n",
|
|||
|
" <td id=\"T_15c37_row4_col7\" class=\"data row4 col7\" >0.997098</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_15c37_level0_row5\" class=\"row_heading level0 row5\" >logistic</th>\n",
|
|||
|
" <td id=\"T_15c37_row5_col0\" class=\"data row5 col0\" >0.996390</td>\n",
|
|||
|
" <td id=\"T_15c37_row5_col1\" class=\"data row5 col1\" >0.996448</td>\n",
|
|||
|
" <td id=\"T_15c37_row5_col2\" class=\"data row5 col2\" >0.925539</td>\n",
|
|||
|
" <td id=\"T_15c37_row5_col3\" class=\"data row5 col3\" >0.922824</td>\n",
|
|||
|
" <td id=\"T_15c37_row5_col4\" class=\"data row5 col4\" >0.990528</td>\n",
|
|||
|
" <td id=\"T_15c37_row5_col5\" class=\"data row5 col5\" >0.990204</td>\n",
|
|||
|
" <td id=\"T_15c37_row5_col6\" class=\"data row5 col6\" >0.959659</td>\n",
|
|||
|
" <td id=\"T_15c37_row5_col7\" class=\"data row5 col7\" >0.958224</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_15c37_level0_row6\" class=\"row_heading level0 row6\" >ridge</th>\n",
|
|||
|
" <td id=\"T_15c37_row6_col0\" class=\"data row6 col0\" >0.886229</td>\n",
|
|||
|
" <td id=\"T_15c37_row6_col1\" class=\"data row6 col1\" >0.887292</td>\n",
|
|||
|
" <td id=\"T_15c37_row6_col2\" class=\"data row6 col2\" >0.999873</td>\n",
|
|||
|
" <td id=\"T_15c37_row6_col3\" class=\"data row6 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_15c37_row6_col4\" class=\"data row6 col4\" >0.984359</td>\n",
|
|||
|
" <td id=\"T_15c37_row6_col5\" class=\"data row6 col5\" >0.984536</td>\n",
|
|||
|
" <td id=\"T_15c37_row6_col6\" class=\"data row6 col6\" >0.939627</td>\n",
|
|||
|
" <td id=\"T_15c37_row6_col7\" class=\"data row6 col7\" >0.940281</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_15c37_level0_row7\" class=\"row_heading level0 row7\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_15c37_row7_col0\" class=\"data row7 col0\" >0.162846</td>\n",
|
|||
|
" <td id=\"T_15c37_row7_col1\" class=\"data row7 col1\" >0.164340</td>\n",
|
|||
|
" <td id=\"T_15c37_row7_col2\" class=\"data row7 col2\" >0.969760</td>\n",
|
|||
|
" <td id=\"T_15c37_row7_col3\" class=\"data row7 col3\" >0.974190</td>\n",
|
|||
|
" <td id=\"T_15c37_row7_col4\" class=\"data row7 col4\" >0.389442</td>\n",
|
|||
|
" <td id=\"T_15c37_row7_col5\" class=\"data row7 col5\" >0.393802</td>\n",
|
|||
|
" <td id=\"T_15c37_row7_col6\" class=\"data row7 col6\" >0.278864</td>\n",
|
|||
|
" <td id=\"T_15c37_row7_col7\" class=\"data row7 col7\" >0.281237</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x2c1a8af10>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(\n",
|
|||
|
" by=\"Accuracy_test\", ascending=False\n",
|
|||
|
").style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Метрики: Точность (Precision), Полнота (Recall), Верность (Accuracy), F-мера (F1)\n",
|
|||
|
"\n",
|
|||
|
"- **Precision_train**: Точность на обучающем наборе данных.\n",
|
|||
|
"- **Precision_test**: Точность на тестовом наборе данных.\n",
|
|||
|
"- **Recall_train**: Полнота на обучающем наборе данных.\n",
|
|||
|
"- **Recall_test**: Полнота на тестовом наборе данных.\n",
|
|||
|
"- **Accuracy_train**: Верность (аккуратность) на обучающем наборе данных.\n",
|
|||
|
"- **Accuracy_test**: Верность (аккуратность) на тестовом наборе данных.\n",
|
|||
|
"- **F1_train**: F-мера на обучающем наборе данных.\n",
|
|||
|
"- **F1_test**: F-мера на тестовом наборе данных.\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"1. **Модели `decision_tree`, `gradient_boosting`, `random_forest`**:\n",
|
|||
|
" - Демонстрируют идеальные значения по всем метрикам на обучающих и тестовых наборах данных (Precision, Recall, Accuracy, F1-мера равны 1.0).\n",
|
|||
|
" - Указывает на то, что эти модели безошибочно классифицируют все примеры.\n",
|
|||
|
"\n",
|
|||
|
"2. **Модель `knn`**:\n",
|
|||
|
" - Показывает очень высокие значения метрик, близкие к 1.0, что указывает на высокую эффективность модели.\n",
|
|||
|
"\n",
|
|||
|
"3. **Модель `mlp`**:\n",
|
|||
|
" - Имеет немного более низкие значения Recall (0.999747) и F1-меры (0.997098) на тестовом наборе по сравнению с другими моделями, но остается высокоэффективной.\n",
|
|||
|
"\n",
|
|||
|
"4. **Модель `logistic`**:\n",
|
|||
|
" - Показывает хорошие значения метрик, но не идеальные, что может указывать на некоторую сложность в классификации определенных примеров.\n",
|
|||
|
"\n",
|
|||
|
"5. **Модель `ridge`**:\n",
|
|||
|
" - Имеет более низкие значения Precision (0.887292) и F1-меры (0.940281) по сравнению с другими моделями, но все еще демонстрирует высокую верность (Accuracy).\n",
|
|||
|
"\n",
|
|||
|
"6. **Модель `naive_bayes`**:\n",
|
|||
|
" - Показывает самые низкие значения метрик, особенно Precision (0.164340) и F1-меры (0.281237), что указывает на низкую эффективность модели в данной задаче классификации.\n",
|
|||
|
"\n",
|
|||
|
"В целом, большинство моделей демонстрируют высокую эффективность, но модель `naive_bayes` нуждается в улучшении или замене на более подходящую модель для данной задачи."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_3b94e_row0_col0, #T_3b94e_row0_col1, #T_3b94e_row1_col0, #T_3b94e_row1_col1, #T_3b94e_row2_col0, #T_3b94e_row2_col1, #T_3b94e_row3_col0, #T_3b94e_row3_col1, #T_3b94e_row4_col0, #T_3b94e_row4_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3b94e_row0_col2, #T_3b94e_row0_col3, #T_3b94e_row0_col4, #T_3b94e_row1_col2, #T_3b94e_row1_col3, #T_3b94e_row1_col4, #T_3b94e_row2_col2, #T_3b94e_row2_col3, #T_3b94e_row2_col4, #T_3b94e_row3_col2, #T_3b94e_row3_col3, #T_3b94e_row3_col4, #T_3b94e_row4_col2, #T_3b94e_row4_col3, #T_3b94e_row4_col4, #T_3b94e_row5_col2, #T_3b94e_row6_col2 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3b94e_row5_col0 {\n",
|
|||
|
" background-color: #a0da39;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3b94e_row5_col1 {\n",
|
|||
|
" background-color: #90d743;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3b94e_row5_col3 {\n",
|
|||
|
" background-color: #d35171;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3b94e_row5_col4 {\n",
|
|||
|
" background-color: #d24f71;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3b94e_row6_col0 {\n",
|
|||
|
" background-color: #a2da37;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3b94e_row6_col1 {\n",
|
|||
|
" background-color: #98d83e;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3b94e_row6_col3, #T_3b94e_row6_col4 {\n",
|
|||
|
" background-color: #d5536f;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3b94e_row7_col0, #T_3b94e_row7_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_3b94e_row7_col2, #T_3b94e_row7_col3, #T_3b94e_row7_col4 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_3b94e\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_3b94e_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_3b94e_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_3b94e_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_3b94e_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_3b94e_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_3b94e_level0_row0\" class=\"row_heading level0 row0\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_3b94e_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_3b94e_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_3b94e_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_3b94e_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_3b94e_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_3b94e_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_3b94e_level0_row3\" class=\"row_heading level0 row3\" >mlp</th>\n",
|
|||
|
" <td id=\"T_3b94e_row3_col0\" class=\"data row3 col0\" >0.999292</td>\n",
|
|||
|
" <td id=\"T_3b94e_row3_col1\" class=\"data row3 col1\" >0.997098</td>\n",
|
|||
|
" <td id=\"T_3b94e_row3_col2\" class=\"data row3 col2\" >0.999992</td>\n",
|
|||
|
" <td id=\"T_3b94e_row3_col3\" class=\"data row3 col3\" >0.996694</td>\n",
|
|||
|
" <td id=\"T_3b94e_row3_col4\" class=\"data row3 col4\" >0.996699</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_3b94e_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
|
|||
|
" <td id=\"T_3b94e_row4_col0\" class=\"data row4 col0\" >0.999507</td>\n",
|
|||
|
" <td id=\"T_3b94e_row4_col1\" class=\"data row4 col1\" >0.997977</td>\n",
|
|||
|
" <td id=\"T_3b94e_row4_col2\" class=\"data row4 col2\" >0.999928</td>\n",
|
|||
|
" <td id=\"T_3b94e_row4_col3\" class=\"data row4 col3\" >0.997697</td>\n",
|
|||
|
" <td id=\"T_3b94e_row4_col4\" class=\"data row4 col4\" >0.997697</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_3b94e_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
|
|||
|
" <td id=\"T_3b94e_row5_col0\" class=\"data row5 col0\" >0.984536</td>\n",
|
|||
|
" <td id=\"T_3b94e_row5_col1\" class=\"data row5 col1\" >0.940281</td>\n",
|
|||
|
" <td id=\"T_3b94e_row5_col2\" class=\"data row5 col2\" >0.999837</td>\n",
|
|||
|
" <td id=\"T_3b94e_row5_col3\" class=\"data row5 col3\" >0.931435</td>\n",
|
|||
|
" <td id=\"T_3b94e_row5_col4\" class=\"data row5 col4\" >0.933632</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_3b94e_level0_row6\" class=\"row_heading level0 row6\" >logistic</th>\n",
|
|||
|
" <td id=\"T_3b94e_row6_col0\" class=\"data row6 col0\" >0.990204</td>\n",
|
|||
|
" <td id=\"T_3b94e_row6_col1\" class=\"data row6 col1\" >0.958224</td>\n",
|
|||
|
" <td id=\"T_3b94e_row6_col2\" class=\"data row6 col2\" >0.999782</td>\n",
|
|||
|
" <td id=\"T_3b94e_row6_col3\" class=\"data row6 col3\" >0.952685</td>\n",
|
|||
|
" <td id=\"T_3b94e_row6_col4\" class=\"data row6 col4\" >0.953585</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_3b94e_level0_row7\" class=\"row_heading level0 row7\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_3b94e_row7_col0\" class=\"data row7 col0\" >0.393802</td>\n",
|
|||
|
" <td id=\"T_3b94e_row7_col1\" class=\"data row7 col1\" >0.281237</td>\n",
|
|||
|
" <td id=\"T_3b94e_row7_col2\" class=\"data row7 col2\" >0.750957</td>\n",
|
|||
|
" <td id=\"T_3b94e_row7_col3\" class=\"data row7 col3\" >0.092090</td>\n",
|
|||
|
" <td id=\"T_3b94e_row7_col4\" class=\"data row7 col4\" >0.209783</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x2c1a36ee0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Создаем DataFrame с метриками для каждой модели\n",
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"\n",
|
|||
|
"# Сортировка по ROC_AUC_test в порядке убывания\n",
|
|||
|
"class_metrics_sorted = class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)\n",
|
|||
|
"\n",
|
|||
|
"# Применение стилей\n",
|
|||
|
"styled_metrics = class_metrics_sorted.style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\", \n",
|
|||
|
" low=0.3, \n",
|
|||
|
" high=1, \n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\", \n",
|
|||
|
" low=1, \n",
|
|||
|
" high=0.3, \n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"display(styled_metrics)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Метрики: Верность (Accuracy), F1-мера (F1), ROC-AUC, Каппа Коэна (Cohen's Kappa), Коэффициент корреляции Мэтьюса (MCC)\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"- **Accuracy_test**: Верность (аккуратность) на тестовом наборе данных.\n",
|
|||
|
"- **F1_test**: F1-мера на тестовом наборе данных.\n",
|
|||
|
"- **ROC_AUC_test**: Площадь под ROC-кривой на тестовом наборе данных.\n",
|
|||
|
"- **Cohen_kappa_test**: Каппа Коэна на тестовом наборе данных.\n",
|
|||
|
"- **MCC_test**: Коэффициент корреляции Мэтьюса на тестовом наборе данных.\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"1. **Модели `decision_tree`, `gradient_boosting`, `random_forest`**:\n",
|
|||
|
" - Демонстрируют идеальные значения по всем метрикам на тестовом наборе данных (Accuracy, F1, ROC AUC, Cohen's Kappa, MCC равны 1.0).\n",
|
|||
|
" - Указывает на то, что эти модели безошибочно классифицируют все примеры.\n",
|
|||
|
"\n",
|
|||
|
"2. **Модель `mip`**:\n",
|
|||
|
" - Показывает очень высокие значения метрик, близкие к 1.0, что указывает на высокую эффективность модели.\n",
|
|||
|
"\n",
|
|||
|
"3. **Модель `knn`**:\n",
|
|||
|
" - Имеет высокие значения метрик, близкие к 1.0, что указывает на высокую эффективность модели.\n",
|
|||
|
"\n",
|
|||
|
"4. **Модель `ridge`**:\n",
|
|||
|
" - Имеет более низкие значения Accuracy (0.984536) и F1-меры (0.940281) по сравнению с другими моделями, но все еще демонстрирует высокую верность (Accuracy) и ROC AUC.\n",
|
|||
|
"\n",
|
|||
|
"5. **Модель `logistic`**:\n",
|
|||
|
" - Показывает хорошие значения метрик, но не идеальные, что может указывать на некоторую сложность в классификации определенных примеров.\n",
|
|||
|
"\n",
|
|||
|
"6. **Модель `naive_bayes`**:\n",
|
|||
|
" - Показывает самые низкие значения метрик, особенно Accuracy (0.978846) и F1-меры (0.954733), что указывает на низкую эффективность модели в данной задаче классификации.\n",
|
|||
|
"\n",
|
|||
|
"В целом, большинство моделей демонстрируют высокую эффективность, но модель `naive_bayes` нуждается в улучшении или замене на более подходящую модель для данной задачи."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'decision_tree'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
|
|||
|
"\n",
|
|||
|
"display(best_model)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Вывод данных с ошибкой предсказания для оценки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'Error items count: 0'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>category</th>\n",
|
|||
|
" <th>Predicted</th>\n",
|
|||
|
" <th>sub_category</th>\n",
|
|||
|
" <th>href</th>\n",
|
|||
|
" <th>price</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"Empty DataFrame\n",
|
|||
|
"Columns: [category, Predicted, sub_category, href, price]\n",
|
|||
|
"Index: []"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Преобразование тестовых данных\n",
|
|||
|
"preprocessing_result = pipeline_end.transform(X_test)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Получение предсказаний лучшей модели\n",
|
|||
|
"y_pred = class_models[best_model][\"preds\"]\n",
|
|||
|
"\n",
|
|||
|
"# Нахождение индексов ошибок\n",
|
|||
|
"error_index = y_test[y_test != y_pred].index.tolist() # Убираем столбец \"above_average_price\"\n",
|
|||
|
"display(f\"Error items count: {len(error_index)}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создание DataFrame с ошибочными объектами\n",
|
|||
|
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
|
|||
|
"error_df = X_test.loc[error_index].copy()\n",
|
|||
|
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
|
|||
|
"error_df = error_df.sort_index() # Сортировка по индексу\n",
|
|||
|
"\n",
|
|||
|
"# Вывод DataFrame с ошибочными объектами\n",
|
|||
|
"display(error_df)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Пример использования обученной модели (конвейера) для предсказания"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>category</th>\n",
|
|||
|
" <th>sub_category</th>\n",
|
|||
|
" <th>href</th>\n",
|
|||
|
" <th>price</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>26987</th>\n",
|
|||
|
" <td>Groceries</td>\n",
|
|||
|
" <td>Home Care</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/home-care/...</td>\n",
|
|||
|
" <td>438.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" category sub_category \\\n",
|
|||
|
"26987 Groceries Home Care \n",
|
|||
|
"\n",
|
|||
|
" href price \n",
|
|||
|
"26987 https://www.jiomart.com/c/groceries/home-care/... 438.0 "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>price</th>\n",
|
|||
|
" <th>category_Electronics</th>\n",
|
|||
|
" <th>category_Fashion</th>\n",
|
|||
|
" <th>category_Groceries</th>\n",
|
|||
|
" <th>category_Home & Kitchen</th>\n",
|
|||
|
" <th>category_Jewellery</th>\n",
|
|||
|
" <th>sub_category_Apparel</th>\n",
|
|||
|
" <th>sub_category_Auto Care</th>\n",
|
|||
|
" <th>sub_category_Ayush</th>\n",
|
|||
|
" <th>sub_category_Bags & Travel Luggage</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>sub_category_Smart Devices</th>\n",
|
|||
|
" <th>sub_category_Snacks & Branded Foods</th>\n",
|
|||
|
" <th>sub_category_Staples</th>\n",
|
|||
|
" <th>sub_category_Stationery</th>\n",
|
|||
|
" <th>sub_category_TV & Speaker</th>\n",
|
|||
|
" <th>sub_category_Tools & Appliances</th>\n",
|
|||
|
" <th>sub_category_Toys, Games & Fitness</th>\n",
|
|||
|
" <th>sub_category_Treatments</th>\n",
|
|||
|
" <th>sub_category_Wellness</th>\n",
|
|||
|
" <th>sub_category_Women</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>26987</th>\n",
|
|||
|
" <td>-0.094382</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>1 rows × 75 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" price category_Electronics category_Fashion category_Groceries \\\n",
|
|||
|
"26987 -0.094382 0.0 0.0 1.0 \n",
|
|||
|
"\n",
|
|||
|
" category_Home & Kitchen category_Jewellery sub_category_Apparel \\\n",
|
|||
|
"26987 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Auto Care sub_category_Ayush \\\n",
|
|||
|
"26987 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Bags & Travel Luggage ... sub_category_Smart Devices \\\n",
|
|||
|
"26987 0.0 ... 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Snacks & Branded Foods sub_category_Staples \\\n",
|
|||
|
"26987 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Stationery sub_category_TV & Speaker \\\n",
|
|||
|
"26987 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Tools & Appliances sub_category_Toys, Games & Fitness \\\n",
|
|||
|
"26987 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Treatments sub_category_Wellness sub_category_Women \n",
|
|||
|
"26987 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
"[1 rows x 75 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"predicted: 0 (proba: [1. 0.])\n",
|
|||
|
"real: 0\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"model = class_models[best_model][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"# Выбираем позиционный индекс объекта для анализа\n",
|
|||
|
"example_index = 13\n",
|
|||
|
"\n",
|
|||
|
"# Получаем исходные данные для объекта\n",
|
|||
|
"test = pd.DataFrame(X_test.iloc[example_index, :]).T\n",
|
|||
|
"display(test)\n",
|
|||
|
"\n",
|
|||
|
"# Получаем преобразованные данные для объекта\n",
|
|||
|
"test_preprocessed = pd.DataFrame(preprocessed_df.iloc[example_index, :]).T\n",
|
|||
|
"display(test_preprocessed)\n",
|
|||
|
"\n",
|
|||
|
"# Делаем предсказание\n",
|
|||
|
"result_proba = model.predict_proba(test)[0]\n",
|
|||
|
"result = model.predict(test)[0]\n",
|
|||
|
"\n",
|
|||
|
"# Получаем реальное значение\n",
|
|||
|
"real = int(y_test.iloc[example_index])\n",
|
|||
|
"\n",
|
|||
|
"# Выводим результаты\n",
|
|||
|
"print(f\"predicted: {result} (proba: {result_proba})\")\n",
|
|||
|
"print(f\"real: {real}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Подбор гиперпараметров методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/numpy/ma/core.py:2846: RuntimeWarning: invalid value encountered in cast\n",
|
|||
|
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'model__criterion': 'gini',\n",
|
|||
|
" 'model__max_depth': 5,\n",
|
|||
|
" 'model__max_features': 'sqrt',\n",
|
|||
|
" 'model__n_estimators': 50}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import GridSearchCV\n",
|
|||
|
"\n",
|
|||
|
"optimized_model_type = \"random_forest\"\n",
|
|||
|
"\n",
|
|||
|
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" \"model__n_estimators\": [10, 50, 100],\n",
|
|||
|
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
|
|||
|
" \"model__max_depth\": [5, 7, 10],\n",
|
|||
|
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"gs_optomizer = GridSearchCV(\n",
|
|||
|
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
|
|||
|
")\n",
|
|||
|
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"gs_optomizer.best_params_"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Обучение модели с новыми гиперпараметрами__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"# Определяем числовые признаки\n",
|
|||
|
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
|
|||
|
"\n",
|
|||
|
"# Установка random_state\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"\n",
|
|||
|
"# Определение трансформера\n",
|
|||
|
"pipeline_end = ColumnTransformer([\n",
|
|||
|
" ('numeric', StandardScaler(), numeric_features),\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Объявление модели\n",
|
|||
|
"optimized_model = RandomForestClassifier(\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" criterion=\"gini\",\n",
|
|||
|
" max_depth=5,\n",
|
|||
|
" max_features=\"sqrt\",\n",
|
|||
|
" n_estimators=10,\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание пайплайна с корректными шагами\n",
|
|||
|
"result = {}\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели\n",
|
|||
|
"result[\"pipeline\"] = Pipeline([\n",
|
|||
|
" (\"pipeline\", pipeline_end),\n",
|
|||
|
" (\"model\", optimized_model)\n",
|
|||
|
"]).fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
"# Прогнозирование и расчет метрик\n",
|
|||
|
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
|
|||
|
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
|
|||
|
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
"# Метрики для оценки модели\n",
|
|||
|
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
|
|||
|
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование данных для оценки старой и новой версии модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 17,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=class_models[optimized_model_type]\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=result\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
|
|||
|
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Оценка параметров старой и новой модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 18,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_cc5ba_row0_col0, #T_cc5ba_row0_col1, #T_cc5ba_row0_col2, #T_cc5ba_row0_col3, #T_cc5ba_row1_col0, #T_cc5ba_row1_col1, #T_cc5ba_row1_col2, #T_cc5ba_row1_col3 {\n",
|
|||
|
" background-color: #440154;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_cc5ba_row0_col4, #T_cc5ba_row0_col5, #T_cc5ba_row0_col6, #T_cc5ba_row0_col7, #T_cc5ba_row1_col4, #T_cc5ba_row1_col5, #T_cc5ba_row1_col6, #T_cc5ba_row1_col7 {\n",
|
|||
|
" background-color: #0d0887;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_cc5ba\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_cc5ba_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_cc5ba_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_cc5ba_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_cc5ba_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_cc5ba_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_cc5ba_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_cc5ba_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_cc5ba_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" <th class=\"blank col5\" > </th>\n",
|
|||
|
" <th class=\"blank col6\" > </th>\n",
|
|||
|
" <th class=\"blank col7\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_cc5ba_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_cc5ba_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_cc5ba_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_cc5ba_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_cc5ba_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x2c2d598b0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 18,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"].style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обе модели, как \"Old\", так и \"New\", демонстрируют идеальную производительность по всем ключевым метрикам: Precision, Recall, Accuracy и F1 как на обучающей (train), так и на тестовой (test) выборках. Все значения равны 1.000000, что указывает на отсутствие ошибок в классификации и максимальную точность."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 19,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_67beb_row0_col0, #T_67beb_row0_col1, #T_67beb_row1_col0, #T_67beb_row1_col1 {\n",
|
|||
|
" background-color: #440154;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_67beb_row0_col2, #T_67beb_row0_col3, #T_67beb_row0_col4, #T_67beb_row1_col2, #T_67beb_row1_col3, #T_67beb_row1_col4 {\n",
|
|||
|
" background-color: #0d0887;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_67beb\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_67beb_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_67beb_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_67beb_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_67beb_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_67beb_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_67beb_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_67beb_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_67beb_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_67beb_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_67beb_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_67beb_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_67beb_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_67beb_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_67beb_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_67beb_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_67beb_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_67beb_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x175f018e0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 19,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"].style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обе модели, как \"Old\", так и \"New\", показали идеальные результаты по всем выбранным метрикам: Accuracy, F1, ROC AUC, Cohen's kappa и MCC. Все метрики имеют значение 1.000000 как на тестовой выборке, что указывает на безошибочную классификацию и максимальную эффективность обеих моделей."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7gAAAGsCAYAAAD34Qv/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABwBUlEQVR4nO3dd3QU1cPG8WeTkAKkEEpCqKGGGgQVQaqUoIg0pQUBRXxVQKUJFjqIothAKYIUBQVFUEBRQErAKL0IiIBUqVISQknbef/IL6NrKAmZkOzm+zlnju7M3Tt3A+TZO/fOHZthGIYAAAAAAHBybtndAAAAAAAArEAHFwAAAADgEujgAgAAAABcAh1cAAAAAIBLoIMLAAAAAHAJdHABAAAAAC6BDi4AAAAAwCXQwQUAAAAAuASP7G4AAAAZde3aNSUkJFhWn6enp7y9vS2rDwCAjCDXrEMHFwDgVK5du6bQUvl16kyyZXUGBwfr0KFDufbLAAAg+5Br1qKDCwBwKgkJCTp1JlmHtpSSn2/m77SJvWRXaK0jSkhIyJVfBAAA2YtcsxYdXACAU/LzdbPkiwAAADkBuWYNOrgAAKeUbNiVbFhTDwAA2Y1cswYdXACAU7LLkF2Z/yZgRR0AAGQWuWYNxsABAAAAAC6BEVwAgFOyyy4rJmFZUwsAAJlDrlmDDi4AwCklG4aSjcxPw7KiDgAAMotcswZTlAEAAAAALoERXACAU2IxDgCAKyHXrEEHFwDglOwylMwXAQCAiyDXrMEUZQAAAACAS2AEFwDglJjKBQBwJeSaNRjBBQAAAAC4BEZwAQBOiccpAABcCblmDTq4AACnZP/fZkU9AABkN3LNGkxRBgAAAAC4BEZwAQBOKdmixylYUQcAAJlFrlmDDi4AwCklGymbFfUAAJDdyDVrMEUZAAAAAOASGMEFADglFuMAALgScs0adHABAE7JLpuSZbOkHgAAshu5Zg2mKAMAAAAAXAIjuAAAp2Q3UjYr6gEAILuRa9ZgBBcAAAAA4BIYwQUAOKVki+5VsqIOAAAyi1yzBh1cAIBT4osAAMCVkGvWYIoyAAAAAMAlMIILAHBKdsMmu2HB4xQsqAMAgMwi16xBBxcA4JSYygUAcCXkmjWYogwAAAAAcAmM4AIAnFKy3JRswXXaZAvaAgBAZpFr1qCDCwBwSoZF9yoZufxeJQBAzkCuWYMpygAAAAAAl8AILgDAKbEYBwDAlZBr1qCDCwBwSsmGm5INC+5VMixoDAAAmUSuWYMpygAAAAAAl8AILgDAKdllk92C67R25fJL3QCAHIFcswYjuAAAAAAAl8AILgDAKbEYBwDAlZBr1qCDCwBwStYtxpG7p3IBAHIGcs0aTFEGAAAAALgERnABAE4pZTGOzE/DsqIOAAAyi1yzBiO4ALLErFmzZLPZdPjw4VuWLV26tHr06JHlbYJrsctNyRZsVqxYCQBAZpFr1sjdnx5Ahu3evVtdu3ZVsWLF5OXlpZCQEEVGRmr37t3Z3TQAAG4o9cKrt7e3/vrrrzTHGzVqpKpVq2ZDywBYiQ4ugHT7+uuvVbNmTa1atUpPPPGEPvroI/Xs2VOrV69WzZo1tWjRouxuInKR1MU4rNgyYty4cbrnnnvk6+urIkWKqE2bNtq3b59DmUaNGslmszlszzzzjEOZo0ePqmXLlsqbN6+KFCmiQYMGKSkpyaHMmjVrVLNmTXl5ealcuXKaNWtWmvZ8+OGHKl26tLy9vVW7dm1t3LgxQ58HyG3i4+P1xhtvZHczgDTItRSZzTU6uADS5eDBg3r88cdVpkwZ7dy5U2PGjFHPnj01evRo7dy5U2XKlNHjjz+uP//8M7ubilzC/r9pWFZsGbF27Vr17t1bv/zyi1asWKHExEQ1b95cly9fdijXq1cvnTx50tzGjx9vHktOTlbLli2VkJCgn3/+WbNnz9asWbM0bNgws8yhQ4fUsmVLNW7cWNu3b9eLL76op556Sj/88INZZv78+erfv7+GDx+urVu3Kjw8XBERETpz5sxt/lQB11ejRg19/PHHOnHiRHY3BXBArlmTa3RwAaTLW2+9pStXrmjatGkqXLiww7FChQpp6tSpunz5ssMvu/8yDENjxoxR8eLFlTdvXjVu3JipzXA6y5cvV48ePVSlShWFh4dr1qxZOnr0qLZs2eJQLm/evAoODjY3Pz8/89iPP/6oPXv26LPPPlONGjX04IMPavTo0frwww+VkJAgSZoyZYpCQ0M1YcIEVapUSX369NGjjz6qd99916znnXfeUa9evfTEE0+ocuXKmjJlivLmzatPPvnkzvwwACf0yiuvKDk5OV2juJ999plq1aolHx8fBQYGqlOnTjp27Jh5/IMPPpC7u7suXrxo7pswYYJsNpv69+9v7ktOTpavr68GDx5s6WcBrOBquUYHF0C6LFmyRKVLl1b9+vWve7xBgwYqXbq0li1bdsM6hg0bpqFDhyo8PFxvvfWWypQpc90rhEB6JBs2yzZJio2Nddji4+PT1Y6YmBhJUmBgoMP+uXPnqlChQqpatapefvllXblyxTwWHR2tatWqKSgoyNwXERGh2NhY86JPdHS0mjZt6lBnRESEoqOjJUkJCQnasmWLQxk3Nzc1bdrULAMgrdDQUHXr1u2Wo7hjx45Vt27dVL58eb3zzjt68cUXtWrVKjVo0MDs0NavX192u13r16833xcVFSU3NzdFRUWZ+7Zt26a4uDg1aNAgyz4XnB+5Zk2u0cEFcEsxMTE6ceKEwsPDb1quevXqOn78uC5dupTm2NmzZzV+/Hi1bNlSS5cuVe/evTVjxgz16NFDf//9d1Y1HUi3EiVKyN/f39zGjRt3y/fY7Xa9+OKLuv/++x0Wp+nSpYs+++wzrV69Wi+//LI+/fRTde3a1Tx+6tQphy8BkszXp06dummZ2NhYXb16VX///beSk5OvWya1DgDX9+qrryopKUlvvvnmdY8fOXJEw4cP15gxY/TFF1/o2Wef1bBhw7R69WodP35cH330kSQpPDxcfn5+ZmfWMAytX79e7du3Nzu10j+d3vvvv//OfEBAuTfXeA4ugFtK7bD6+vretFzq8djY2DTHVq5cqYSEBPXt21c22z/PZ3vxxRf1+uuvW9ha5Bapj0PIfD2GJOnYsWMO0628vLxu+d7evXvrt99+cxi9kaSnn37a/P9q1aqpaNGiatKkiQ4ePKiyZctmus0AMid13Yhp06ZpyJAhKlq0qMPxr7/+Wna7XR06dHC4CBscHKzy5ctr9erVeuWVV+Tm5qa6detq3bp1kqS9e/fq3LlzGjJkiBYuXKjo6Gg1a9ZMUVFRqlq1qgICAu7kx4STIdeswQgugFtK7bheb2T2327WET5y5IgkqXz58g77CxcurAIFCljRTOQydsPNsk2S/Pz8HLZbfRHo06ePli5dqtWrV6t48eI3LVu7dm1J0oEDBySlfEk+ffq0Q5nU18HBwTct4+fnJx8fHxUqVEju7u7XLZNaB4Abe+2115SUlHTde3H3798vwzBUvnx5FS5c2GHbu3evw4I39evX15YtW3T16lVFRUWpaNGiqlmzpsLDw82R3fXr19/wFh8gFblmTa7RwQVwS/7+/ipatKh27tx503I7d+5UsWLFHK4WAq7GMAz16dNHixYt0k8//aTQ0NBbvmf79u2SZI4S1alTR7t27XL4krxixQr5+fmpcuXKZplVq1Y51LNixQrVqVNHkuTp6alatWo5lLHb7Vq1apVZBsCNlSlTRl27dtW0adN08uRJh2N2u102m03Lly/XihUr0mxTp041y9arV0+JiYmKjo5WVFSU2ZGtX7++oqKi9Pvvv+vs2bN0cJFjuVquMUUZQLo8/PDD+vjjj7V+/XrVq1cvzfGoqCgdPnxY//d//3fd95cqVUpSylXxMmXKmPvPnj2rCxcuZE2j4dKsnsqVXr1799a8efP0zTffyNfX17wvyN/fXz4+Pjp48KDmzZunhx56SAULFtTOnTvVr18/NWjQQNWrV5ckNW/eXJUrV9bjjz+u8ePH69SpU3rttdfUu3dv8wr7M88
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x400 with 4 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"for index in range(0, len(optimized_metrics)):\n",
|
|||
|
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"Below Average\", \"Above Average\"]\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(optimized_metrics.index[index]) \n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"В желтом квадрате мы видим значение 28511, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"Below Average\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
|
|||
|
"\n",
|
|||
|
"В зеленом квадрате значение 3952 указывает на количество правильно классифицированных объектов, отнесенных к классу \"Above Average\". Это также является показателем высокой точности модели в определении объектов данного класса."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Загрузка данных и создание целевой переменной"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Среднее значение поля 'price': 1991.6325132793531\n",
|
|||
|
" category sub_category \\\n",
|
|||
|
"0 Groceries Fruits & Vegetables \n",
|
|||
|
"1 Groceries Fruits & Vegetables \n",
|
|||
|
"2 Groceries Fruits & Vegetables \n",
|
|||
|
"3 Groceries Fruits & Vegetables \n",
|
|||
|
"4 Groceries Fruits & Vegetables \n",
|
|||
|
"\n",
|
|||
|
" href \\\n",
|
|||
|
"0 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"1 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"2 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"3 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"4 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"\n",
|
|||
|
" items price \\\n",
|
|||
|
"0 Fresh Dates (Pack) (Approx 450 g - 500 g) 109.0 \n",
|
|||
|
"1 Tender Coconut Cling Wrapped (1 pc) (Approx 90... 49.0 \n",
|
|||
|
"2 Mosambi 1 kg 69.0 \n",
|
|||
|
"3 Orange Imported 1 kg 125.0 \n",
|
|||
|
"4 Banana Robusta 6 pcs (Box) (Approx 800 g - 110... 44.0 \n",
|
|||
|
"\n",
|
|||
|
" above_average_price \n",
|
|||
|
"0 0 \n",
|
|||
|
"1 0 \n",
|
|||
|
"2 0 \n",
|
|||
|
"3 0 \n",
|
|||
|
"4 0 \n",
|
|||
|
"Статистическое описание DataFrame:\n",
|
|||
|
" price above_average_price\n",
|
|||
|
"count 1.622820e+05 162313.000000\n",
|
|||
|
"mean 1.991633e+03 0.121734\n",
|
|||
|
"std 1.593479e+04 0.326979\n",
|
|||
|
"min 5.000000e+00 0.000000\n",
|
|||
|
"25% 2.840000e+02 0.000000\n",
|
|||
|
"50% 4.990000e+02 0.000000\n",
|
|||
|
"75% 9.990000e+02 0.000000\n",
|
|||
|
"max 3.900000e+06 1.000000\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"\n",
|
|||
|
"# Загрузка данных\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//jio_mart_items.csv\")\n",
|
|||
|
"\n",
|
|||
|
"# Опция для настройки генерации случайных чисел \n",
|
|||
|
"random_state = 42\n",
|
|||
|
"\n",
|
|||
|
"# Вычисление среднего значения поля \"price\"\n",
|
|||
|
"average_price = df['price'].mean()\n",
|
|||
|
"print(f\"Среднее значение поля 'price': {average_price}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создание новой колонки, указывающей, выше или ниже среднего значение цены\n",
|
|||
|
"df['above_average_price'] = (df['price'] > average_price).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод DataFrame с новой колонкой\n",
|
|||
|
"print(df.head())\n",
|
|||
|
"\n",
|
|||
|
"# Примерный анализ данных\n",
|
|||
|
"print(\"Статистическое описание DataFrame:\")\n",
|
|||
|
"print(df.describe())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии\n",
|
|||
|
"\n",
|
|||
|
"Целевой признак -- above_average_price"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 22,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>category</th>\n",
|
|||
|
" <th>sub_category</th>\n",
|
|||
|
" <th>href</th>\n",
|
|||
|
" <th>items</th>\n",
|
|||
|
" <th>price</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>38475</th>\n",
|
|||
|
" <td>Groceries</td>\n",
|
|||
|
" <td>Mom & Baby Care</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/mom-baby-c...</td>\n",
|
|||
|
" <td>Halo Nation Green Plastic Wobbling Roly Poly T...</td>\n",
|
|||
|
" <td>529.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3550</th>\n",
|
|||
|
" <td>Groceries</td>\n",
|
|||
|
" <td>Staples</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/staples/ri...</td>\n",
|
|||
|
" <td>OrgaSatva Organic Sona Masuri Rice (White) 1 kg</td>\n",
|
|||
|
" <td>420.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>145206</th>\n",
|
|||
|
" <td>Electronics</td>\n",
|
|||
|
" <td>Accessories</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/electronics/accessor...</td>\n",
|
|||
|
" <td>itek 10000 mAh Power Bank, RBB013_BK</td>\n",
|
|||
|
" <td>1099.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>151588</th>\n",
|
|||
|
" <td>Beauty</td>\n",
|
|||
|
" <td>Make-Up</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/beauty/make-up/lips/...</td>\n",
|
|||
|
" <td>Fashion Colour Satin Smooth Lip Definer, 14 Ab...</td>\n",
|
|||
|
" <td>356.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>28297</th>\n",
|
|||
|
" <td>Groceries</td>\n",
|
|||
|
" <td>Home Care</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/home-care/...</td>\n",
|
|||
|
" <td>My Home Lavender Trail Air Freshener Block 50 ...</td>\n",
|
|||
|
" <td>65.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>119879</th>\n",
|
|||
|
" <td>Fashion</td>\n",
|
|||
|
" <td>Women</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/fashion/women/bags-b...</td>\n",
|
|||
|
" <td>Trysco Women Genuine Leather Yellow Belt</td>\n",
|
|||
|
" <td>599.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>103694</th>\n",
|
|||
|
" <td>Home & Kitchen</td>\n",
|
|||
|
" <td>Pooja Needs</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/home-kitch...</td>\n",
|
|||
|
" <td>Majmua Attar Made Pure and Natural Exclusive I...</td>\n",
|
|||
|
" <td>599.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>131932</th>\n",
|
|||
|
" <td>Fashion</td>\n",
|
|||
|
" <td>Girls</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/fashion/girls/watche...</td>\n",
|
|||
|
" <td>Mikado Analog Blue Watch For Girls ,Pack Of 2</td>\n",
|
|||
|
" <td>249.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>146867</th>\n",
|
|||
|
" <td>Electronics</td>\n",
|
|||
|
" <td>Accessories</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/electronics/accessor...</td>\n",
|
|||
|
" <td>Reconnect RACMB1001 Car Mount</td>\n",
|
|||
|
" <td>100.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>121958</th>\n",
|
|||
|
" <td>Fashion</td>\n",
|
|||
|
" <td>Women</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/fashion/women/fashio...</td>\n",
|
|||
|
" <td>Traditional Long Earring Zinc Jhumki Earring (...</td>\n",
|
|||
|
" <td>129.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>129850 rows × 5 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" category sub_category \\\n",
|
|||
|
"38475 Groceries Mom & Baby Care \n",
|
|||
|
"3550 Groceries Staples \n",
|
|||
|
"145206 Electronics Accessories \n",
|
|||
|
"151588 Beauty Make-Up \n",
|
|||
|
"28297 Groceries Home Care \n",
|
|||
|
"... ... ... \n",
|
|||
|
"119879 Fashion Women \n",
|
|||
|
"103694 Home & Kitchen Pooja Needs \n",
|
|||
|
"131932 Fashion Girls \n",
|
|||
|
"146867 Electronics Accessories \n",
|
|||
|
"121958 Fashion Women \n",
|
|||
|
"\n",
|
|||
|
" href \\\n",
|
|||
|
"38475 https://www.jiomart.com/c/groceries/mom-baby-c... \n",
|
|||
|
"3550 https://www.jiomart.com/c/groceries/staples/ri... \n",
|
|||
|
"145206 https://www.jiomart.com/c/electronics/accessor... \n",
|
|||
|
"151588 https://www.jiomart.com/c/beauty/make-up/lips/... \n",
|
|||
|
"28297 https://www.jiomart.com/c/groceries/home-care/... \n",
|
|||
|
"... ... \n",
|
|||
|
"119879 https://www.jiomart.com/c/fashion/women/bags-b... \n",
|
|||
|
"103694 https://www.jiomart.com/c/groceries/home-kitch... \n",
|
|||
|
"131932 https://www.jiomart.com/c/fashion/girls/watche... \n",
|
|||
|
"146867 https://www.jiomart.com/c/electronics/accessor... \n",
|
|||
|
"121958 https://www.jiomart.com/c/fashion/women/fashio... \n",
|
|||
|
"\n",
|
|||
|
" items price \n",
|
|||
|
"38475 Halo Nation Green Plastic Wobbling Roly Poly T... 529.0 \n",
|
|||
|
"3550 OrgaSatva Organic Sona Masuri Rice (White) 1 kg 420.0 \n",
|
|||
|
"145206 itek 10000 mAh Power Bank, RBB013_BK 1099.0 \n",
|
|||
|
"151588 Fashion Colour Satin Smooth Lip Definer, 14 Ab... 356.0 \n",
|
|||
|
"28297 My Home Lavender Trail Air Freshener Block 50 ... 65.0 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"119879 Trysco Women Genuine Leather Yellow Belt 599.0 \n",
|
|||
|
"103694 Majmua Attar Made Pure and Natural Exclusive I... 599.0 \n",
|
|||
|
"131932 Mikado Analog Blue Watch For Girls ,Pack Of 2 249.0 \n",
|
|||
|
"146867 Reconnect RACMB1001 Car Mount 100.0 \n",
|
|||
|
"121958 Traditional Long Earring Zinc Jhumki Earring (... 129.0 \n",
|
|||
|
"\n",
|
|||
|
"[129850 rows x 5 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>above_average_price</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>38475</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3550</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>145206</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>151588</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>28297</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>119879</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>103694</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>131932</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>146867</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>121958</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>129850 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" above_average_price\n",
|
|||
|
"38475 0\n",
|
|||
|
"3550 0\n",
|
|||
|
"145206 0\n",
|
|||
|
"151588 0\n",
|
|||
|
"28297 0\n",
|
|||
|
"... ...\n",
|
|||
|
"119879 0\n",
|
|||
|
"103694 0\n",
|
|||
|
"131932 0\n",
|
|||
|
"146867 0\n",
|
|||
|
"121958 0\n",
|
|||
|
"\n",
|
|||
|
"[129850 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>category</th>\n",
|
|||
|
" <th>sub_category</th>\n",
|
|||
|
" <th>href</th>\n",
|
|||
|
" <th>items</th>\n",
|
|||
|
" <th>price</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>52893</th>\n",
|
|||
|
" <td>Home & Kitchen</td>\n",
|
|||
|
" <td>Dining</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/home-kitch...</td>\n",
|
|||
|
" <td>CRAFTYKART Brown Shesham Wood Serving Tray ?35...</td>\n",
|
|||
|
" <td>699.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>78308</th>\n",
|
|||
|
" <td>Home & Kitchen</td>\n",
|
|||
|
" <td>Toys, Games & Fitness</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/home-kitch...</td>\n",
|
|||
|
" <td>Magicwand Red ABS Plastic 4Wd 360 Degree Twist...</td>\n",
|
|||
|
" <td>7999.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>159477</th>\n",
|
|||
|
" <td>Beauty</td>\n",
|
|||
|
" <td>Fragrances</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/beauty/fragrances/wo...</td>\n",
|
|||
|
" <td>Ajmal Senora EDP Floral Spicy Perfume And Sacr...</td>\n",
|
|||
|
" <td>1295.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>74384</th>\n",
|
|||
|
" <td>Home & Kitchen</td>\n",
|
|||
|
" <td>Toys, Games & Fitness</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/home-kitch...</td>\n",
|
|||
|
" <td>Frantic Ultra Soft Stuffed Lovable Spongy Huga...</td>\n",
|
|||
|
" <td>369.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>93511</th>\n",
|
|||
|
" <td>Home & Kitchen</td>\n",
|
|||
|
" <td>Bags & Travel Luggage</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/home-kitch...</td>\n",
|
|||
|
" <td>DE VAGABOND Orange Black Polyester Travel Duff...</td>\n",
|
|||
|
" <td>749.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>117300</th>\n",
|
|||
|
" <td>Fashion</td>\n",
|
|||
|
" <td>Women</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/fashion/women/wester...</td>\n",
|
|||
|
" <td>Tees World Women Grey Regular Fit Round Neck P...</td>\n",
|
|||
|
" <td>999.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>24023</th>\n",
|
|||
|
" <td>Groceries</td>\n",
|
|||
|
" <td>Personal Care</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/personal-c...</td>\n",
|
|||
|
" <td>Vetoni Fruit Punch Lather Shaving Cream for Me...</td>\n",
|
|||
|
" <td>300.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>129165</th>\n",
|
|||
|
" <td>Fashion</td>\n",
|
|||
|
" <td>Girls</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/fashion/girls/wester...</td>\n",
|
|||
|
" <td>IndiWeaves Girls Printed Cotton Half Sleeves T...</td>\n",
|
|||
|
" <td>799.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>71336</th>\n",
|
|||
|
" <td>Home & Kitchen</td>\n",
|
|||
|
" <td>Furniture</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/groceries/home-kitch...</td>\n",
|
|||
|
" <td>EVEREST DRAWER V</td>\n",
|
|||
|
" <td>2081.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>110968</th>\n",
|
|||
|
" <td>Fashion</td>\n",
|
|||
|
" <td>Men</td>\n",
|
|||
|
" <td>https://www.jiomart.com/c/fashion/men/footwear...</td>\n",
|
|||
|
" <td>Birde Sports Shoes</td>\n",
|
|||
|
" <td>399.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>32463 rows × 5 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" category sub_category \\\n",
|
|||
|
"52893 Home & Kitchen Dining \n",
|
|||
|
"78308 Home & Kitchen Toys, Games & Fitness \n",
|
|||
|
"159477 Beauty Fragrances \n",
|
|||
|
"74384 Home & Kitchen Toys, Games & Fitness \n",
|
|||
|
"93511 Home & Kitchen Bags & Travel Luggage \n",
|
|||
|
"... ... ... \n",
|
|||
|
"117300 Fashion Women \n",
|
|||
|
"24023 Groceries Personal Care \n",
|
|||
|
"129165 Fashion Girls \n",
|
|||
|
"71336 Home & Kitchen Furniture \n",
|
|||
|
"110968 Fashion Men \n",
|
|||
|
"\n",
|
|||
|
" href \\\n",
|
|||
|
"52893 https://www.jiomart.com/c/groceries/home-kitch... \n",
|
|||
|
"78308 https://www.jiomart.com/c/groceries/home-kitch... \n",
|
|||
|
"159477 https://www.jiomart.com/c/beauty/fragrances/wo... \n",
|
|||
|
"74384 https://www.jiomart.com/c/groceries/home-kitch... \n",
|
|||
|
"93511 https://www.jiomart.com/c/groceries/home-kitch... \n",
|
|||
|
"... ... \n",
|
|||
|
"117300 https://www.jiomart.com/c/fashion/women/wester... \n",
|
|||
|
"24023 https://www.jiomart.com/c/groceries/personal-c... \n",
|
|||
|
"129165 https://www.jiomart.com/c/fashion/girls/wester... \n",
|
|||
|
"71336 https://www.jiomart.com/c/groceries/home-kitch... \n",
|
|||
|
"110968 https://www.jiomart.com/c/fashion/men/footwear... \n",
|
|||
|
"\n",
|
|||
|
" items price \n",
|
|||
|
"52893 CRAFTYKART Brown Shesham Wood Serving Tray ?35... 699.0 \n",
|
|||
|
"78308 Magicwand Red ABS Plastic 4Wd 360 Degree Twist... 7999.0 \n",
|
|||
|
"159477 Ajmal Senora EDP Floral Spicy Perfume And Sacr... 1295.0 \n",
|
|||
|
"74384 Frantic Ultra Soft Stuffed Lovable Spongy Huga... 369.0 \n",
|
|||
|
"93511 DE VAGABOND Orange Black Polyester Travel Duff... 749.0 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"117300 Tees World Women Grey Regular Fit Round Neck P... 999.0 \n",
|
|||
|
"24023 Vetoni Fruit Punch Lather Shaving Cream for Me... 300.0 \n",
|
|||
|
"129165 IndiWeaves Girls Printed Cotton Half Sleeves T... 799.0 \n",
|
|||
|
"71336 EVEREST DRAWER V 2081.0 \n",
|
|||
|
"110968 Birde Sports Shoes 399.0 \n",
|
|||
|
"\n",
|
|||
|
"[32463 rows x 5 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>above_average_price</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>52893</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>78308</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>159477</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>74384</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>93511</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>117300</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>24023</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>129165</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>71336</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>110968</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>32463 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" above_average_price\n",
|
|||
|
"52893 0\n",
|
|||
|
"78308 1\n",
|
|||
|
"159477 0\n",
|
|||
|
"74384 0\n",
|
|||
|
"93511 0\n",
|
|||
|
"... ...\n",
|
|||
|
"117300 0\n",
|
|||
|
"24023 0\n",
|
|||
|
"129165 0\n",
|
|||
|
"71336 1\n",
|
|||
|
"110968 0\n",
|
|||
|
"\n",
|
|||
|
"[32463 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from typing import Tuple\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from pandas import DataFrame\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"def split_into_train_test(\n",
|
|||
|
" df_input: DataFrame,\n",
|
|||
|
" target_colname: str = \"above_average_price\", \n",
|
|||
|
" frac_train: float = 0.8,\n",
|
|||
|
" random_state: int = None,\n",
|
|||
|
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
|||
|
" \n",
|
|||
|
" if not (0 < frac_train < 1):\n",
|
|||
|
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
|
|||
|
" \n",
|
|||
|
" # Проверка наличия целевого признака\n",
|
|||
|
" if target_colname not in df_input.columns:\n",
|
|||
|
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
|
|||
|
" \n",
|
|||
|
" # Разделяем данные на признаки и целевую переменную\n",
|
|||
|
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
|
|||
|
" y = df_input[[target_colname]] # Целевая переменная\n",
|
|||
|
"\n",
|
|||
|
" # Разделяем данные на обучающую и тестовую выборки\n",
|
|||
|
" X_train, X_test, y_train, y_test = train_test_split(\n",
|
|||
|
" X, y,\n",
|
|||
|
" test_size=(1.0 - frac_train),\n",
|
|||
|
" random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" return X_train, X_test, y_train, y_test\n",
|
|||
|
"\n",
|
|||
|
"# Применение функции для разделения данных\n",
|
|||
|
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
|
|||
|
" df, \n",
|
|||
|
" target_colname=\"above_average_price\", \n",
|
|||
|
" frac_train=0.8, \n",
|
|||
|
" random_state=42 \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Для отображения результатов\n",
|
|||
|
"display(\"X_train\", X_train)\n",
|
|||
|
"display(\"y_train\", y_train)\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_test\", X_test)\n",
|
|||
|
"display(\"y_test\", y_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование конвейера для решения задачи регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 23,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" price category_Electronics category_Fashion category_Groceries \\\n",
|
|||
|
"0 -0.118140 0.0 0.0 1.0 \n",
|
|||
|
"1 -0.121905 0.0 0.0 1.0 \n",
|
|||
|
"2 -0.120650 0.0 0.0 1.0 \n",
|
|||
|
"3 -0.117136 0.0 0.0 1.0 \n",
|
|||
|
"4 -0.122219 0.0 0.0 1.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"162308 -0.020231 0.0 0.0 0.0 \n",
|
|||
|
"162309 -0.037679 0.0 0.0 0.0 \n",
|
|||
|
"162310 -0.072637 0.0 0.0 0.0 \n",
|
|||
|
"162311 0.017865 0.0 0.0 0.0 \n",
|
|||
|
"162312 -0.072637 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" category_Home & Kitchen category_Jewellery sub_category_Apparel \\\n",
|
|||
|
"0 0.0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"162308 0.0 1.0 0.0 \n",
|
|||
|
"162309 0.0 1.0 0.0 \n",
|
|||
|
"162310 0.0 1.0 0.0 \n",
|
|||
|
"162311 0.0 1.0 0.0 \n",
|
|||
|
"162312 0.0 1.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Auto Care sub_category_Ayush \\\n",
|
|||
|
"0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"162308 0.0 0.0 \n",
|
|||
|
"162309 0.0 0.0 \n",
|
|||
|
"162310 0.0 0.0 \n",
|
|||
|
"162311 0.0 0.0 \n",
|
|||
|
"162312 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Bags & Travel Luggage ... \\\n",
|
|||
|
"0 0.0 ... \n",
|
|||
|
"1 0.0 ... \n",
|
|||
|
"2 0.0 ... \n",
|
|||
|
"3 0.0 ... \n",
|
|||
|
"4 0.0 ... \n",
|
|||
|
"... ... ... \n",
|
|||
|
"162308 0.0 ... \n",
|
|||
|
"162309 0.0 ... \n",
|
|||
|
"162310 0.0 ... \n",
|
|||
|
"162311 0.0 ... \n",
|
|||
|
"162312 0.0 ... \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Snacks & Branded Foods sub_category_Staples \\\n",
|
|||
|
"0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"162308 0.0 0.0 \n",
|
|||
|
"162309 0.0 0.0 \n",
|
|||
|
"162310 0.0 0.0 \n",
|
|||
|
"162311 0.0 0.0 \n",
|
|||
|
"162312 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Stationery sub_category_TV & Speaker sub_category_Tech \\\n",
|
|||
|
"0 0.0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"162308 0.0 0.0 0.0 \n",
|
|||
|
"162309 0.0 0.0 0.0 \n",
|
|||
|
"162310 0.0 0.0 0.0 \n",
|
|||
|
"162311 0.0 0.0 0.0 \n",
|
|||
|
"162312 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Tools & Appliances sub_category_Toys, Games & Fitness \\\n",
|
|||
|
"0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"162308 0.0 0.0 \n",
|
|||
|
"162309 0.0 0.0 \n",
|
|||
|
"162310 0.0 0.0 \n",
|
|||
|
"162311 0.0 0.0 \n",
|
|||
|
"162312 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" sub_category_Treatments sub_category_Wellness sub_category_Women \n",
|
|||
|
"0 0.0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"162308 0.0 0.0 0.0 \n",
|
|||
|
"162309 0.0 0.0 0.0 \n",
|
|||
|
"162310 0.0 0.0 0.0 \n",
|
|||
|
"162311 0.0 0.0 0.0 \n",
|
|||
|
"162312 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
"[162313 rows x 77 columns]\n",
|
|||
|
"(162313, 77)\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn.base import BaseEstimator, TransformerMixin\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor \n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.pipeline import make_pipeline\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"class JioMartFeatures(BaseEstimator, TransformerMixin): \n",
|
|||
|
" def __init__(self):\n",
|
|||
|
" pass\n",
|
|||
|
"\n",
|
|||
|
" def fit(self, X, y=None):\n",
|
|||
|
" return self\n",
|
|||
|
"\n",
|
|||
|
" def transform(self, X, y=None):\n",
|
|||
|
" if 'category' in X.columns:\n",
|
|||
|
" X[\"Price_per_Category\"] = X[\"price\"] / X[\"category\"].nunique()\n",
|
|||
|
" return X\n",
|
|||
|
"\n",
|
|||
|
" def get_feature_names_out(self, features_in):\n",
|
|||
|
" return np.append(features_in, [\"Price_per_Category\"], axis=0) \n",
|
|||
|
"\n",
|
|||
|
"# Определите признаки для вашей задачи\n",
|
|||
|
"columns_to_drop = [\"href\", \"items\"] \n",
|
|||
|
"num_columns = [\"price\"] \n",
|
|||
|
"cat_columns = [\"category\", \"sub_category\"]\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование числовых признаков\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование категориальных признаков\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Формирование конвейера\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\" \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Окончательный конвейер\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" (\"custom_features\", JioMartFeatures()), # Добавляем custom_features\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Загрузка данных\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//jio_mart_items.csv\")\n",
|
|||
|
"\n",
|
|||
|
"# Создаем целевой признак\n",
|
|||
|
"average_price = df['price'].mean()\n",
|
|||
|
"df['above_average_price'] = (df['price'] > average_price).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Подготовка данных\n",
|
|||
|
"X = df.drop('above_average_price', axis=1)\n",
|
|||
|
"y = df['above_average_price'].values.ravel()\n",
|
|||
|
"\n",
|
|||
|
"# Проверка наличия столбцов перед применением конвейера\n",
|
|||
|
"required_columns = set(num_columns + cat_columns + columns_to_drop)\n",
|
|||
|
"missing_columns = required_columns - set(X.columns)\n",
|
|||
|
"if missing_columns:\n",
|
|||
|
" raise KeyError(f\"Missing columns: {missing_columns}\")\n",
|
|||
|
"\n",
|
|||
|
"# Применение конвейера\n",
|
|||
|
"X_processed = pipeline_end.fit_transform(X)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод\n",
|
|||
|
"print(X_processed)\n",
|
|||
|
"print(X_processed.shape)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование набора моделей для регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 28,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Random Forest: Mean Score = 0.9897752006377067, Standard Deviation = 0.012886225390386691\n",
|
|||
|
"Linear Regression: Mean Score = -1.439679711903671e+21, Standard Deviation = 1.9848730981021744e+21\n",
|
|||
|
"Gradient Boosting: Mean Score = 0.990533312551943, Standard Deviation = 0.01338791677558754\n",
|
|||
|
"Support Vector Regression: Mean Score = 0.6408179773886161, Standard Deviation = 0.045968161125540155\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/ensemble/_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/utils/validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True)\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/utils/validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True)\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/utils/validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True)\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.linear_model import LinearRegression\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|||
|
"from sklearn.model_selection import cross_val_score\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.ensemble import GradientBoostingRegressor\n",
|
|||
|
"from sklearn.svm import SVR\n",
|
|||
|
"\n",
|
|||
|
"def train_multiple_models(X, y, models, cv=3):\n",
|
|||
|
" results = {}\n",
|
|||
|
" for model_name, model in models.items():\n",
|
|||
|
" # Создаем конвейер для каждой модели\n",
|
|||
|
" model_pipeline = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" (\"model\", model) # Используем текущую модель\n",
|
|||
|
" ]\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" # Обучаем модель и вычисляем кросс-валидацию\n",
|
|||
|
" scores = cross_val_score(model_pipeline, X, y, cv=cv, n_jobs=-1) # Используем все ядра процессора\n",
|
|||
|
" results[model_name] = {\n",
|
|||
|
" \"mean_score\": scores.mean(),\n",
|
|||
|
" \"std_dev\": scores.std()\n",
|
|||
|
" }\n",
|
|||
|
" \n",
|
|||
|
" return results\n",
|
|||
|
"\n",
|
|||
|
"# Определение моделей\n",
|
|||
|
"models = {\n",
|
|||
|
" \"Random Forest\": RandomForestRegressor(n_estimators=10), # Уменьшаем количество деревьев\n",
|
|||
|
" \"Linear Regression\": LinearRegression(),\n",
|
|||
|
" \"Gradient Boosting\": GradientBoostingRegressor(),\n",
|
|||
|
" \"Support Vector Regression\": SVR()\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Используем подвыборку данных\n",
|
|||
|
"sample_size = 1000 # Уменьшаем количество данных для обучения\n",
|
|||
|
"X_train_sample = X_train.sample(n=sample_size, random_state=42)\n",
|
|||
|
"y_train_sample = y_train.loc[X_train_sample.index] # Используем loc для индексации Series\n",
|
|||
|
"\n",
|
|||
|
"# Обучение моделей и вывод результатов\n",
|
|||
|
"results = train_multiple_models(X_train_sample, y_train_sample, models, cv=3) # Уменьшаем количество фолдов\n",
|
|||
|
"\n",
|
|||
|
"# Вывод результатов\n",
|
|||
|
"for model_name, scores in results.items():\n",
|
|||
|
" print(f\"{model_name}: Mean Score = {scores['mean_score']}, Standard Deviation = {scores['std_dev']}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Модель: Random Forest\n",
|
|||
|
"- **Mean Score**: 0.9897752006377067\n",
|
|||
|
"- **Standard Deviation**: 0.012886225390386691\n",
|
|||
|
"**Описание**:\n",
|
|||
|
"- Random Forest показала очень высокое среднее значение, близкое к 1, что указывает на ее высокую точность в предсказании. Стандартное отклонение также относительно низкое, что говорит о стабильности модели.\n",
|
|||
|
"\n",
|
|||
|
"#### Модель: Linear Regression\n",
|
|||
|
"- **Mean Score**: -1.439679711903671e+21\n",
|
|||
|
"- **Standard Deviation**: 1.9848730981021744e+21\n",
|
|||
|
"**Описание**:\n",
|
|||
|
"- Линейная регрессия показала очень низкое среднее значение с огромным отрицательным числом, что указывает на ее неэффективность в данной задаче. Стандартное отклонение также очень высокое, что говорит о нестабильности модели.\n",
|
|||
|
"\n",
|
|||
|
"#### Модель: Gradient Boosting\n",
|
|||
|
"- **Mean Score**: 0.990533312551943\n",
|
|||
|
"- **Standard Deviation**: 0.01338791677558754\n",
|
|||
|
"**Описание**:\n",
|
|||
|
"- Gradient Boosting показала практически идеальное среднее значение, близкое к 1, что указывает на ее высокую точность в предсказании. Стандартное отклонение относительно низкое, что говорит о стабильности модели.\n",
|
|||
|
"\n",
|
|||
|
"#### Модель: Support Vector Regression\n",
|
|||
|
"- **Mean Score**: 0.6408179773886161\n",
|
|||
|
"- **Standard Deviation**: 0.045968161125540155\n",
|
|||
|
"**Описание**:\n",
|
|||
|
"- Support Vector Regression показала среднее значение около 0.64, что указывает на ее умеренную точность в предсказании. Стандартное отклонение относительно низкое, что говорит о стабильности модели, но она все же уступает Random Forest и Gradient Boosting.\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"1. **Random Forest и Gradient Boosting** демонстрируют высокую точность и стабильность, что делает их наиболее подходящими моделями для данной задачи регрессии.\n",
|
|||
|
"2. **Linear Regression** неэффективна и нестабильна, что указывает на необходимость ее замены на более подходящую модель.\n",
|
|||
|
"3. **Support Vector Regression** показывает умеренную точность и стабильность, но уступает Random Forest и Gradient Boosting в эффективности."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение моделей на обучающем наборе данных и оценка на тестовом для регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 29,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n",
|
|||
|
"MSE (train): 0.00954948016942626\n",
|
|||
|
"MSE (test): 0.009857376089702122\n",
|
|||
|
"MAE (train): 0.00954948016942626\n",
|
|||
|
"MAE (test): 0.009857376089702122\n",
|
|||
|
"R2 (train): 0.9105001240660583\n",
|
|||
|
"R2 (test): 0.9085410706513222\n",
|
|||
|
"STD (train): 0.09733042790899017\n",
|
|||
|
"STD (test): 0.09886474010790139\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"MSE (train): 0.016395841355410088\n",
|
|||
|
"MSE (test): 0.016418692049410096\n",
|
|||
|
"MAE (train): 0.016395841355410088\n",
|
|||
|
"MAE (test): 0.016418692049410096\n",
|
|||
|
"R2 (train): 0.8463344872069661\n",
|
|||
|
"R2 (test): 0.8476637208036084\n",
|
|||
|
"STD (train): 0.12699418323145514\n",
|
|||
|
"STD (test): 0.1270791824052891\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: decision_tree\n",
|
|||
|
"MSE (train): 0.0\n",
|
|||
|
"MSE (test): 0.0\n",
|
|||
|
"MAE (train): 0.0\n",
|
|||
|
"MAE (test): 0.0\n",
|
|||
|
"R2 (train): 1.0\n",
|
|||
|
"R2 (test): 1.0\n",
|
|||
|
"STD (train): 0.0\n",
|
|||
|
"STD (test): 0.0\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: knn\n",
|
|||
|
"MSE (train): 0.00041586445899114365\n",
|
|||
|
"MSE (test): 0.0004928688044851062\n",
|
|||
|
"MAE (train): 0.00041586445899114365\n",
|
|||
|
"MAE (test): 0.0004928688044851062\n",
|
|||
|
"R2 (train): 0.9961024247577155\n",
|
|||
|
"R2 (test): 0.9954270535325661\n",
|
|||
|
"STD (train): 0.020392609645475155\n",
|
|||
|
"STD (test): 0.022199280947150013\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: naive_bayes\n",
|
|||
|
"MSE (train): 0.6530150173276857\n",
|
|||
|
"MSE (test): 0.6539752949511752\n",
|
|||
|
"MAE (train): 0.6530150173276857\n",
|
|||
|
"MAE (test): 0.6539752949511752\n",
|
|||
|
"R2 (train): -5.1202036128569794\n",
|
|||
|
"R2 (test): -5.0677283439763485\n",
|
|||
|
"STD (train): 0.4850279840944924\n",
|
|||
|
"STD (test): 0.4855381725252704\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: gradient_boosting\n",
|
|||
|
"MSE (train): 0.0\n",
|
|||
|
"MSE (test): 0.0\n",
|
|||
|
"MAE (train): 0.0\n",
|
|||
|
"MAE (test): 0.0\n",
|
|||
|
"R2 (train): 1.0\n",
|
|||
|
"R2 (test): 1.0\n",
|
|||
|
"STD (train): 0.0\n",
|
|||
|
"STD (test): 0.0\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: random_forest\n",
|
|||
|
"MSE (train): 0.0\n",
|
|||
|
"MSE (test): 0.0\n",
|
|||
|
"MAE (train): 0.0\n",
|
|||
|
"MAE (test): 0.0\n",
|
|||
|
"R2 (train): 1.0\n",
|
|||
|
"R2 (test): 1.0\n",
|
|||
|
"STD (train): 0.0\n",
|
|||
|
"STD (test): 0.0\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: mlp\n",
|
|||
|
"MSE (train): 0.0009703504043126684\n",
|
|||
|
"MSE (test): 0.0010781505098111696\n",
|
|||
|
"MAE (train): 0.0009703504043126684\n",
|
|||
|
"MAE (test): 0.0010781505098111696\n",
|
|||
|
"R2 (train): 0.9909056577680027\n",
|
|||
|
"R2 (test): 0.9899966796024884\n",
|
|||
|
"STD (train): 0.031139749763093583\n",
|
|||
|
"STD (test): 0.03281946301141911\n",
|
|||
|
"----------------------------------------\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"\n",
|
|||
|
"# Проверка наличия необходимых переменных\n",
|
|||
|
"if 'class_models' not in locals():\n",
|
|||
|
" raise ValueError(\"class_models is not defined\")\n",
|
|||
|
"if 'X_train' not in locals() or 'X_test' not in locals() or 'y_train' not in locals() or 'y_test' not in locals():\n",
|
|||
|
" raise ValueError(\"Train/test data is not defined\")\n",
|
|||
|
"\n",
|
|||
|
"# Преобразуем y_train и y_test в одномерные массивы\n",
|
|||
|
"y_train = np.ravel(y_train) \n",
|
|||
|
"y_test = np.ravel(y_test) \n",
|
|||
|
"\n",
|
|||
|
"# Инициализация списка для хранения результатов\n",
|
|||
|
"results = []\n",
|
|||
|
"\n",
|
|||
|
"# Проход по моделям и оценка их качества\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" \n",
|
|||
|
" # Извлечение модели из словаря\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
" \n",
|
|||
|
" # Создание пайплайна\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
|||
|
" \n",
|
|||
|
" # Обучение модели\n",
|
|||
|
" model_pipeline.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
" # Предсказание для обучающей и тестовой выборки\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_predict = model_pipeline.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
" # Сохранение пайплайна и предсказаний\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" # Вычисление метрик для регрессии\n",
|
|||
|
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
|
|||
|
"\n",
|
|||
|
" # Дополнительные метрики\n",
|
|||
|
" class_models[model_name][\"STD_train\"] = np.std(y_train - y_train_predict)\n",
|
|||
|
" class_models[model_name][\"STD_test\"] = np.std(y_test - y_test_predict)\n",
|
|||
|
"\n",
|
|||
|
" # Вывод результатов для текущей модели\n",
|
|||
|
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
|
|||
|
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
|
|||
|
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
|
|||
|
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
|
|||
|
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
|
|||
|
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
|
|||
|
" print(f\"STD (train): {class_models[model_name]['STD_train']}\")\n",
|
|||
|
" print(f\"STD (test): {class_models[model_name]['STD_test']}\")\n",
|
|||
|
" print(\"-\" * 40) # Разделитель для разных моделей"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Пример использования обученной модели (конвейера регрессии) для предсказания"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 36,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: RandomForest\n",
|
|||
|
"MSE (train): 8419071.042944524\n",
|
|||
|
"MSE (test): 1708514.4521493362\n",
|
|||
|
"MAE (train): 11.216263715771229\n",
|
|||
|
"MAE (test): 14.19769129925748\n",
|
|||
|
"R2 (train): 0.9638189510993855\n",
|
|||
|
"R2 (test): 0.9949568688066726\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Прогнозируемая цена: 5.77\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor \n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder\n",
|
|||
|
"\n",
|
|||
|
"# 1. Загрузка данных\n",
|
|||
|
"data = pd.read_csv(\"..//static//csv//jio_mart_items.csv\") \n",
|
|||
|
"\n",
|
|||
|
"# 2. Подготовка данных для прогноза\n",
|
|||
|
"average_price = data['price'].mean()\n",
|
|||
|
"data['above_average_price'] = (data['price'] > average_price).astype(int) \n",
|
|||
|
"\n",
|
|||
|
"# Удаляем строки с пропущенными значениями в столбце 'price'\n",
|
|||
|
"data = data.dropna(subset=['price'])\n",
|
|||
|
"\n",
|
|||
|
"# Предикторы и целевая переменная\n",
|
|||
|
"X = data.drop('above_average_price', axis=1) # Удаляем только 'above_average_price'\n",
|
|||
|
"y = data['price']\n",
|
|||
|
"\n",
|
|||
|
"# 3. Инициализация модели и пайплайна\n",
|
|||
|
"class_models = {\n",
|
|||
|
" \"RandomForest\": {\n",
|
|||
|
" \"model\": RandomForestRegressor(n_estimators=100, random_state=42),\n",
|
|||
|
" }\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Предобработка признаков\n",
|
|||
|
"num_columns = ['price']\n",
|
|||
|
"cat_columns = ['category', 'sub_category']\n",
|
|||
|
"\n",
|
|||
|
"# Проверка наличия столбцов перед предобработкой\n",
|
|||
|
"required_columns = set(num_columns + cat_columns)\n",
|
|||
|
"missing_columns = required_columns - set(X.columns)\n",
|
|||
|
"if missing_columns:\n",
|
|||
|
" raise KeyError(f\"Missing columns: {missing_columns}\")\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование числовых признаков\n",
|
|||
|
"num_transformer = Pipeline(steps=[\n",
|
|||
|
" ('imputer', SimpleImputer(strategy='median')),\n",
|
|||
|
" ('scaler', StandardScaler())\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование категориальных признаков\n",
|
|||
|
"cat_transformer = Pipeline(steps=[\n",
|
|||
|
" ('imputer', SimpleImputer(strategy='constant', fill_value='unknown')),\n",
|
|||
|
" ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop=\"first\"))\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Создание конвейера предобработки\n",
|
|||
|
"preprocessor = ColumnTransformer(\n",
|
|||
|
" transformers=[\n",
|
|||
|
" ('num', num_transformer, num_columns),\n",
|
|||
|
" ('cat', cat_transformer, cat_columns)\n",
|
|||
|
" ])\n",
|
|||
|
"\n",
|
|||
|
"# Создание конвейера модели\n",
|
|||
|
"pipeline_end = Pipeline(steps=[\n",
|
|||
|
" ('preprocessor', preprocessor),\n",
|
|||
|
" # ('model', model) # Модель добавляется в цикле\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"results = []\n",
|
|||
|
"\n",
|
|||
|
"# 4. Обучение модели и оценка\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
"\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
" model_pipeline = Pipeline(steps=[\n",
|
|||
|
" ('preprocessor', preprocessor),\n",
|
|||
|
" ('model', model)\n",
|
|||
|
" ])\n",
|
|||
|
"\n",
|
|||
|
" # Разделение данных\n",
|
|||
|
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
|
|||
|
"\n",
|
|||
|
" # Обучение модели\n",
|
|||
|
" model_pipeline.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
" # Предсказание\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_predict = model_pipeline.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
" # Сохранение результатов\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" # Вычисление метрик\n",
|
|||
|
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
|
|||
|
"\n",
|
|||
|
" # Вывод результатов\n",
|
|||
|
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
|
|||
|
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
|
|||
|
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
|
|||
|
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
|
|||
|
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
|
|||
|
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
|
|||
|
" print(\"-\" * 40)\n",
|
|||
|
"\n",
|
|||
|
"# Прогнозирование цены для нового товара\n",
|
|||
|
"new_item_data = pd.DataFrame({\n",
|
|||
|
" 'category': ['Electronics'],\n",
|
|||
|
" 'sub_category': ['Smartphones'], \n",
|
|||
|
" 'price': [0] # Добавляем столбец 'price' с нулевым значением\n",
|
|||
|
"})\n",
|
|||
|
"\n",
|
|||
|
"predicted_price = model_pipeline.predict(new_item_data)\n",
|
|||
|
"print(f\"Прогнозируемая цена: {predicted_price[0]}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Подбор гиперпараметров методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 37,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 12.4s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 12.6s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 12.6s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 23.3s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 23.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 23.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 10.8s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 11.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 11.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=200; total time= 44.9s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 21.9s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 22.0s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=200; total time= 45.4s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 22.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 12.0s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 12.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 12.3s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=200; total time= 46.4s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=100; total time= 23.9s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=200; total time= 46.0s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 7.5s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=200; total time= 47.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=100; total time= 24.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=100; total time= 24.4s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 8.1s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 7.9s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=200; total time= 48.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=100; total time= 15.7s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=100; total time= 15.9s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=100; total time= 15.2s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=5, n_estimators=50; total time= 7.6s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=200; total time= 48.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=200; total time= 48.3s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=5, n_estimators=50; total time= 7.9s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=5, n_estimators=50; total time= 8.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=200; total time= 50.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=200; total time= 32.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=100; total time= 16.6s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=100; total time= 17.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=100; total time= 16.8s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=200; total time= 32.8s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=200; total time= 32.8s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=50; total time= 8.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=50; total time= 8.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=50; total time= 8.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=100; total time= 15.4s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=200; total time= 31.9s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=100; total time= 15.8s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=100; total time= 15.8s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=200; total time= 32.7s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=200; total time= 32.4s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=2, n_estimators=50; total time= 11.0s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=2, n_estimators=50; total time= 11.1s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=2, n_estimators=50; total time= 11.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=200; total time= 31.4s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=200; total time= 32.7s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=200; total time= 32.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=100; total time= 21.6s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=100; total time= 22.2s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=5, n_estimators=50; total time= 11.7s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=100; total time= 22.2s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=5, n_estimators=50; total time= 12.2s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=5, n_estimators=50; total time= 12.4s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=100; total time= 24.3s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=100; total time= 23.7s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=100; total time= 24.7s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=200; total time= 46.2s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=200; total time= 46.6s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=200; total time= 48.0s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=50; total time= 11.8s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=50; total time= 11.9s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=50; total time= 11.8s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=100; total time= 24.2s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=200; total time= 47.3s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=200; total time= 49.4s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=100; total time= 25.3s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=100; total time= 25.8s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=2, n_estimators=50; total time= 14.1s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=2, n_estimators=50; total time= 14.5s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=200; total time= 50.1s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=2, n_estimators=50; total time= 14.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=200; total time= 50.5s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=200; total time= 50.9s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=100; total time= 28.5s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=100; total time= 30.1s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=100; total time= 29.8s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=5, n_estimators=50; total time= 14.6s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=200; total time= 52.1s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=5, n_estimators=50; total time= 15.1s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=5, n_estimators=50; total time= 14.4s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=200; total time= 57.7s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=100; total time= 28.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=200; total time= 57.7s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=100; total time= 29.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=100; total time= 29.3s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=50; total time= 15.3s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=50; total time= 15.5s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=200; total time= 59.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=50; total time= 15.3s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=100; total time= 29.4s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=200; total time= 58.3s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=200; total time= 59.1s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=100; total time= 29.0s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=100; total time= 29.0s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=200; total time= 53.8s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=200; total time= 44.9s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=200; total time= 45.4s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=200; total time= 35.2s\n",
|
|||
|
"Лучшие параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 100}\n",
|
|||
|
"Лучший результат (MSE): 206320633.70862785\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"\n",
|
|||
|
"# Удаление строк с пропущенными значениями (если необходимо)\n",
|
|||
|
"df = df.dropna()\n",
|
|||
|
"\n",
|
|||
|
"# Создание целевой переменной (price)\n",
|
|||
|
"target = df['price']\n",
|
|||
|
"\n",
|
|||
|
"# Удаление целевой переменной из исходных данных\n",
|
|||
|
"features = df.drop(columns=['price'])\n",
|
|||
|
"\n",
|
|||
|
"# Удаление столбцов, которые не будут использоваться (например, href и items)\n",
|
|||
|
"features = features.drop(columns=['href', 'items'])\n",
|
|||
|
"\n",
|
|||
|
"# Определение столбцов для обработки\n",
|
|||
|
"num_columns = features.select_dtypes(include=['number']).columns\n",
|
|||
|
"cat_columns = features.select_dtypes(include=['object']).columns\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг числовых столбцов\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\") # Используем медиану для заполнения пропущенных значений в числовых столбцах\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг категориальных столбцов\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\") # Используем 'unknown' для заполнения пропущенных значений в категориальных столбцах\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Объединение препроцессинга\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание финального пайплайна\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Разделение данных на обучающую и тестовую выборки\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)\n",
|
|||
|
"\n",
|
|||
|
"# Применение пайплайна к данным\n",
|
|||
|
"X_train_processed = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"X_test_processed = pipeline_end.transform(X_test)\n",
|
|||
|
"\n",
|
|||
|
"# 2. Создание и настройка модели случайного леса\n",
|
|||
|
"model = RandomForestRegressor()\n",
|
|||
|
"\n",
|
|||
|
"# Установка параметров для поиска по сетке\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
|
|||
|
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
|
|||
|
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# 3. Подбор гиперпараметров с помощью Grid Search\n",
|
|||
|
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели на тренировочных данных\n",
|
|||
|
"grid_search.fit(X_train_processed, y_train)\n",
|
|||
|
"\n",
|
|||
|
"# 4. Результаты подбора гиперпараметров\n",
|
|||
|
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
|
|||
|
"print(\"Лучший результат (MSE):\", -grid_search.best_score_) # Меняем знак, так как берем отрицательное значение среднеквадратичной ошибки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение модели с новыми гиперпараметрами и сравнение новых и старых данных"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 44,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" category sub_category \\\n",
|
|||
|
"0 Groceries Fruits & Vegetables \n",
|
|||
|
"1 Groceries Fruits & Vegetables \n",
|
|||
|
"2 Groceries Fruits & Vegetables \n",
|
|||
|
"3 Groceries Fruits & Vegetables \n",
|
|||
|
"4 Groceries Fruits & Vegetables \n",
|
|||
|
"\n",
|
|||
|
" href \\\n",
|
|||
|
"0 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"1 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"2 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"3 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"4 https://www.jiomart.com/c/groceries/fruits-veg... \n",
|
|||
|
"\n",
|
|||
|
" items price \n",
|
|||
|
"0 Fresh Dates (Pack) (Approx 450 g - 500 g) 109.0 \n",
|
|||
|
"1 Tender Coconut Cling Wrapped (1 pc) (Approx 90... 49.0 \n",
|
|||
|
"2 Mosambi 1 kg 69.0 \n",
|
|||
|
"3 Orange Imported 1 kg 125.0 \n",
|
|||
|
"4 Banana Robusta 6 pcs (Box) (Approx 800 g - 110... 44.0 \n",
|
|||
|
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=100; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=5, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=5, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=2, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=10, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END .max_depth=10, min_samples_split=5, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=None, min_samples_split=5, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=200; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=2, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=2, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=2, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=5, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=200; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=5, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=5, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END .max_depth=20, min_samples_split=5, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=10, min_samples_split=10, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=200; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=2, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=2, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=5, n_estimators=50; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=2, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=5, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=5, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=20, min_samples_split=10, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=2, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=100; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=200; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=200; total time= 0.1s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=10, n_estimators=200; total time= 0.2s\n",
|
|||
|
"[CV] END .max_depth=30, min_samples_split=5, n_estimators=50; total time= 0.0s\n",
|
|||
|
"[CV] END max_depth=30, min_samples_split=5, n_estimators=100; total time= 0.0s\n",
|
|||
|
"Старые параметры: {'max_depth': 30, 'min_samples_split': 5, 'n_estimators': 50}\n",
|
|||
|
"Лучший результат (MSE) на старых параметрах: 4352.70053925649\n",
|
|||
|
"\n",
|
|||
|
"Новые параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 200}\n",
|
|||
|
"Лучший результат (MSE) на новых параметрах: 4862.953305666657\n",
|
|||
|
"Среднеквадратическая ошибка (MSE) на тестовых данных: 3485.772883899025\n",
|
|||
|
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 59.04043431326556\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAHWCAYAAAB9mLjgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADnaklEQVR4nOzdd1wT9/8H8FcIe4OigICogOIG90DBBdqqFfeo4rats7V1tFpHK617tbZ+bbWOuhVXrZuKigsLKm6UpSDK3iO53x/87kpIgAQSLuP9fDzy0Nxd7t65XMK97/P5vE/AMAwDQgghhBBCCCFy0+M7AEIIIYQQQgjRNJRIEUIIIYQQQoiCKJEihBBCCCGEEAVRIkUIIYQQQgghCqJEihBCCCGEEEIURIkUIYQQQgghhCiIEilCCCGEEEIIURAlUoQQQgghhBCiIEqkCCGEEEIIIURBlEgRQgghhJAaOX36NCIjI7nnISEhiI6O5i8gQmoBJVKEaLiYmBhMnz4djRs3hrGxMSwtLdGtWzds2rQJ+fn5fIdHCCFEBzx48ABz5szB8+fPcfPmTcyYMQPZ2dl8h0WISgkYhmH4DoIQUj1nzpzB8OHDYWRkhPHjx6Nly5YoKirCtWvXcPToUQQFBWH79u18h0kIIUTLvXv3Dl27dsWLFy8AAIGBgTh69CjPURGiWpRIEaKhXr16hdatW8PJyQmXL1+Gg4ODxPwXL17gzJkzmDNnDk8REkII0SWFhYV4+PAhTE1N4enpyXc4hKgcde0jREOtXr0aOTk5+O2336SSKABwc3OTSKIEAgFmzpyJffv2oWnTpjA2Nka7du1w9epVidfFxcXh008/RdOmTWFiYoI6depg+PDhiI2NlVhu165dEAgE3MPU1BStWrXCjh07JJYLCgqCubm5VHxHjhyBQCBAaGioxPRbt24hICAAVlZWMDU1Rc+ePXH9+nWJZZYtWwaBQID3799LTL979y4EAgF27dolsX1XV1eJ5RISEmBiYgKBQCD1vs6ePQsfHx+YmZnBwsICH3zwgVz9/Nn9cfXqVUyfPh116tSBpaUlxo8fj/T0dKnl5dnO/fv3ERQUxHXbtLe3x6RJk5CamiozBldXV4nPhH2U3ceurq748MMPK30vsbGxEAgEWLt2rdS8li1bwtfXl3seGhoKgUCAI0eOVLi+8p/Bt99+Cz09PVy6dEliuWnTpsHQ0BBRUVGVxicQCLBs2TKJaWvWrIFAIJCIrbLXV/QoG2fZ/bBhwwY0bNgQJiYm6NmzJx4+fCi13idPnmDYsGGwtbWFsbEx2rdvj5MnT8qMISgoSOb2g4KCpJY9e/YsevbsCQsLC1haWqJDhw74888/ufm+vr5S7/v777+Hnp6exHJhYWEYPnw4XFxcYGRkBGdnZ8ybN0+qC/CyZcvQvHlzmJubw9LSEp07d0ZISIjEMvKuS5Hvv6+vL1q2bCm17Nq1a6W+q1Udx+xxya7/8ePHMDExwfjx4yWWu3btGoRCIRYsWFDhugD59oki8Z84cQIffPABHB0dYWRkhCZNmmDlypUQiUQSr5V1rLO/NdX57VL08yh/XN25c4c7VmXFaWRkhHbt2sHT01Oh7yQhmkqf7wAIIdVz6tQpNG7cGF27dpX7Nf/88w8OHjyI2bNnw8jICD///DMCAgJw+/Zt7gTgzp07uHHjBkaNGgUnJyfExsZi27Zt8PX1xaNHj2Bqaiqxzg0bNqBu3brIysrC77//jqlTp8LV1RV9+vRR+D1dvnwZ/fv3R7t27biT7Z07d6JXr14ICwtDx44dFV6nLEuXLkVBQYHU9D179mDChAnw9/fHjz/+iLy8PGzbtg3du3fHv//+K5WQyTJz5kxYW1tj2bJlePr0KbZt24a4uDjuxE6R7Vy4cAEvX77ExIkTYW9vj+joaGzfvh3R0dG4efOm1MkMAPj4+GDatGkASk8eV61aVf0dpSLffPMNTp06hcmTJ+PBgwewsLDAuXPn8L///Q8rV65EmzZtFFpfRkYGgoODFXpN3759pU6q161bJzPp3b17N7Kzs/HZZ5+hoKAAmzZtQq9evfDgwQPUr18fABAdHY1u3bqhQYMGWLhwIczMzHDo0CF89NFHOHr0KIYMGSK1XiMjI4kLD1OmTJFaZteuXZg0aRJatGiBRYsWwdraGv/++y/+/vtvjBkzRuZ727lzJ7755husW7dOYpnDhw8jLy8Pn3zyCerUqYPbt29jy5YtSExMxOHDh7nlcnNzMWTIELi6uiI/Px+7du3C0KFDER4ezn0H5V2XuvD09MTKlSvx5ZdfYtiwYRg0aBByc3MRFBSEZs2aYcWKFZW+Xp59oohdu3bB3Nwcn3/+OczNzXH58mUsXboUWVlZWLNmjcLrU8ZvlzyqSjhZ1flOEqKRGEKIxsnMzGQAMIMHD5b7NQAYAMzdu3e5aXFxcYyxsTEzZMgQblpeXp7Ua8PDwxkAzO7du7lpO3fuZAAwr1694qY9e/aMAcCsXr2amzZhwgTGzMxMap2HDx9mADBXrlxhGIZhxGIx4+7uzvj7+zNisVginkaNGjF9+/blpn377bcMAObdu3cS67xz5w4DgNm5c6fE9hs2bMg9f/jwIaOnp8f0799fIv7s7GzG2tqamTp1qsQ6k5OTGSsrK6np5bH7o127dkxRURE3ffXq1QwA5sSJEwpvR9ZnsX//fgYAc/XqVal5DRo0YCZOnMg9v3LlisQ+ZhiGadiwIfPBBx9U+l5evXrFAGDWrFkjNa9FixZMz549pbZx+PDhCtdX/jNgGIZ58OABY2hoyEyZMoVJT09nGjRowLRv354pLi6uNDaGKT2Wv/32W+75V199xdSrV49p166dRGyVvf6zzz6Tmv7BBx9IxMnuBxMTEyYxMZGbfuvWLQYAM2/ePG5a7969mVatWjEFBQXcNLFYzHTt2pVxd3eX2taYMWMYc3NziWlmZmbMhAkTuOcZGRmMhYUF06lTJyY/P19i2bLfkZ49e3Lv+8yZM4y+vj7zxRdfSG1T1vEUHBzMCAQCJi4uTmoeKyUlhQHArF27VuF1yfv9Z99HixYtpJZds2aN1G9NVcexrGNfJBIx3bt3Z+rXr8+8f/+e+eyzzxh9fX3mzp07Fa6nIrL2iSLxy9p/06dPZ0xNTSWOIYFAwCxdulRiufK/vYr8pij6eZT9Pv31118MACYgIIApf/pY0+8kIZqKuvYRooGysrIAABYWFgq9rkuXLmjXrh333MXFBYMHD8a5c+e4LiUmJibc/OLiYqSmpsLNzQ3W1ta4d++e1DrT09Px/v17vHz5Ehs2bIBQKETPnj2llnv//r3Eo3w1p8jISDx//hxjxoxBamoqt1xubi569+6Nq1evQiwWS7wmLS1NYp2ZmZlV7oNFixbB29sbw4cPl5h+4cIFZGRkYPTo0RLrFAqF6NSpE65cuVLluoHS7mkGBgbc808++QT6+vr466+/FN5O2c+ioKAA79+/R+fOnQFA5mdRVFQEIyOjKmMsLi7G+/fvkZqaipKSkgqXy8vLk/rcync9YmVnZ+P9+/fIyMiocvtAaRfB5cuXY8eOHfD398f79+/xxx9/QF9fsY4Sr1+/xpYtW7BkyRKZXZaU4aOPPkKDBg245x07dkSnTp24zzQtLQ2XL1/GiBEjuP3A7l9/f388f/4cr1+/llhnQUEBjI2NK93uhQsXkJ2djYULF0otK6s18vbt2xgxYgSGDh0qs1Wj7PGUm5uL9+/fo2vXrmAYBv/++6/EsuwxEhMTgx9++AF6enro1q1btdYFVP39Z4lEIqll8/LyZC4r73HM0tPTw65du5CTk4P+/fvj559/xqJFi9C+ffsqX1t2exXtE0XiL7v/2GPGx8cHeXl5ePLkCTevXr16SExMrDSu6vx2yft5sBiGwaJFizB06FB06tSp0mVr4ztJiLqgrn2EaCBLS0sAULi0rLu7u9Q0Dw8P5OXl4d27d7C3t0d+fj6Cg4Oxc+dOvH79GkyZejSyEhVvb2/u/0ZGRti6datUV5fc3FzY2dlVGtvz588BABMmTKhwmczMTNjY2HD
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x500 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|||
|
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"# Загрузка датасета\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//jio_mart_items.csv\").head(100)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод первых строк для проверки структуры\n",
|
|||
|
"print(df.head())\n",
|
|||
|
"\n",
|
|||
|
"# Целевая переменная\n",
|
|||
|
"target = df['price']\n",
|
|||
|
"\n",
|
|||
|
"# Удаление целевой переменной из признаков\n",
|
|||
|
"features = df.drop(columns=['price', 'href'])\n",
|
|||
|
"\n",
|
|||
|
"# Определение столбцов для обработки\n",
|
|||
|
"num_columns = features.select_dtypes(include=['number']).columns\n",
|
|||
|
"cat_columns = features.select_dtypes(include=['object']).columns\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг числовых столбцов\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline([\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг категориальных столбцов\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline([\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Объединение препроцессинга\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание финального пайплайна\n",
|
|||
|
"pipeline_end = Pipeline([\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Разделение данных на обучающую и тестовую выборки\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)\n",
|
|||
|
"\n",
|
|||
|
"# Применение пайплайна к данным\n",
|
|||
|
"X_train_processed = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"X_test_processed = pipeline_end.transform(X_test)\n",
|
|||
|
"\n",
|
|||
|
"# 1. Настройка параметров для старых значений\n",
|
|||
|
"old_param_grid = {\n",
|
|||
|
" 'n_estimators': [50, 100, 200],\n",
|
|||
|
" 'max_depth': [None, 10, 20, 30],\n",
|
|||
|
" 'min_samples_split': [2, 5, 10]\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Подбор гиперпараметров с помощью Grid Search для старых параметров\n",
|
|||
|
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(),\n",
|
|||
|
" param_grid=old_param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели на тренировочных данных\n",
|
|||
|
"old_grid_search.fit(X_train_processed, y_train)\n",
|
|||
|
"\n",
|
|||
|
"# Результаты подбора для старых параметров\n",
|
|||
|
"old_best_params = old_grid_search.best_params_\n",
|
|||
|
"old_best_mse = -old_grid_search.best_score_\n",
|
|||
|
"\n",
|
|||
|
"# 2. Настройка параметров для новых значений\n",
|
|||
|
"new_param_grid = {\n",
|
|||
|
" 'n_estimators': [200],\n",
|
|||
|
" 'max_depth': [10],\n",
|
|||
|
" 'min_samples_split': [10]\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Подбор гиперпараметров с помощью Grid Search для новых параметров\n",
|
|||
|
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(),\n",
|
|||
|
" param_grid=new_param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', cv=2)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели на тренировочных данных\n",
|
|||
|
"new_grid_search.fit(X_train_processed, y_train)\n",
|
|||
|
"\n",
|
|||
|
"# Результаты подбора для новых параметров\n",
|
|||
|
"new_best_params = new_grid_search.best_params_\n",
|
|||
|
"new_best_mse = -new_grid_search.best_score_\n",
|
|||
|
"\n",
|
|||
|
"# 5. Обучение модели с лучшими параметрами для новых значений\n",
|
|||
|
"model_best = RandomForestRegressor(**new_best_params)\n",
|
|||
|
"model_best.fit(X_train_processed, y_train)\n",
|
|||
|
"\n",
|
|||
|
"# Прогнозирование на тестовой выборке\n",
|
|||
|
"y_pred = model_best.predict(X_test_processed)\n",
|
|||
|
"\n",
|
|||
|
"# Оценка производительности модели\n",
|
|||
|
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
|
|||
|
"rmse = np.sqrt(mse)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод результатов\n",
|
|||
|
"print(\"Старые параметры:\", old_best_params)\n",
|
|||
|
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
|
|||
|
"print(\"\\nНовые параметры:\", new_best_params)\n",
|
|||
|
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
|
|||
|
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
|
|||
|
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели с лучшими параметрами для старых значений\n",
|
|||
|
"model_old = RandomForestRegressor(**old_best_params)\n",
|
|||
|
"model_old.fit(X_train_processed, y_train)\n",
|
|||
|
"\n",
|
|||
|
"# Прогнозирование на тестовой выборке для старых параметров\n",
|
|||
|
"y_pred_old = model_old.predict(X_test_processed)\n",
|
|||
|
"\n",
|
|||
|
"# Визуализация ошибок\n",
|
|||
|
"plt.figure(figsize=(10, 5))\n",
|
|||
|
"plt.plot(y_test.values, label='Реальные значения', marker='o', linestyle='-', color='black')\n",
|
|||
|
"plt.plot(y_pred_old, label='Предсказанные значения (старые параметры)', marker='x', linestyle='--', color='blue')\n",
|
|||
|
"plt.plot(y_pred, label='Предсказанные значения (новые параметры)', marker='s', linestyle='--', color='orange')\n",
|
|||
|
"plt.xlabel('Объекты')\n",
|
|||
|
"plt.ylabel('Цена')\n",
|
|||
|
"plt.title('Сравнение реальных и предсказанных значений')\n",
|
|||
|
"plt.legend()\n",
|
|||
|
"plt.show()\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.9.7"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|