6626 lines
663 KiB
Plaintext
6626 lines
663 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Вариант 19: Данные о миллионерах"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Index(['Rank ', 'Name', 'Networth', 'Age', 'Country', 'Source', 'Industry'], dtype='object')\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd \n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//Forbes Billionaires.csv\")\n",
|
|||
|
"print(df.columns)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Бизнес-цели\n",
|
|||
|
"\n",
|
|||
|
"### Задача классификации\n",
|
|||
|
"Классифицировать людей по уровню состояния.\n",
|
|||
|
"\n",
|
|||
|
"Цель:\n",
|
|||
|
"\n",
|
|||
|
"Разработать модель машинного обучения, которая будет классифицировать миллиардеров по состоянию, выше или ниже среднего.\n",
|
|||
|
"\n",
|
|||
|
"В обучении модели машинного обучения для классификации миллиардеров по уровню богатства, помимо чистого состояния, используются и другие столбцы данных:\n",
|
|||
|
"- Возраст: Люди с высоким чистым состоянием, как правило, старше. Модель может использовать возраст как признак, чтобы прогнозировать уровень богатства.\n",
|
|||
|
"- Страна: Богатство распределяется неравномерно по миру. Страна проживания может быть важным признаком для предсказания уровня богатства.\n",
|
|||
|
"- Отрасль: Определенные отрасли (например, финансы, технологии) часто связаны с высоким чистым состоянием. \n",
|
|||
|
"\n",
|
|||
|
"### Задача регрессии:\n",
|
|||
|
"Прогнозирование чистого состояния (Networth):\n",
|
|||
|
"\n",
|
|||
|
"Цель: Предсказать абсолютное значение чистого состояния миллиардера, используя информацию из имеющихся данных.\n",
|
|||
|
"\n",
|
|||
|
"Применение: Это может быть полезно для оценки потенциального состояния миллиардеров в будущем или для сравнения миллиардеров в разных странах и отраслях.\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Определение достижимого уровня качества модели для первой задачи "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Создание целевой переменной и предварительная обработка данных"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Среднее значение поля 'Networth': 4.8607499999999995\n",
|
|||
|
" Rank Name Networth Age Country \\\n",
|
|||
|
"0 1 Elon Musk 219.0 50 United States \n",
|
|||
|
"1 2 Jeff Bezos 171.0 58 United States \n",
|
|||
|
"2 3 Bernard Arnault & family 158.0 73 France \n",
|
|||
|
"3 4 Bill Gates 129.0 66 United States \n",
|
|||
|
"4 5 Warren Buffett 118.0 91 United States \n",
|
|||
|
"\n",
|
|||
|
" Source Industry above_average_networth \n",
|
|||
|
"0 Tesla, SpaceX Automotive 1 \n",
|
|||
|
"1 Amazon Technology 1 \n",
|
|||
|
"2 LVMH Fashion & Retail 1 \n",
|
|||
|
"3 Microsoft Technology 1 \n",
|
|||
|
"4 Berkshire Hathaway Finance & Investments 1 \n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"\n",
|
|||
|
"# Установим параметры для вывода\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"\n",
|
|||
|
"# Устанавливаем случайное состояние\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"# Можно использовать данные о above_average_networth для анализа зависимости между типом источника богатства и чистым состоянием.\n",
|
|||
|
"# Рассчитываем среднее значение чистого состояния\n",
|
|||
|
"average_networth = df['Networth'].mean()\n",
|
|||
|
"print(f\"Среднее значение поля 'Networth': {average_networth}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создаем новую переменную, указывающую, превышает ли чистое состояние среднее\n",
|
|||
|
"df['above_average_networth'] = (df['Networth'] > average_networth).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Выводим первые строки измененной таблицы для проверки\n",
|
|||
|
"print(df.head())\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n",
|
|||
|
"\n",
|
|||
|
"Целевой признак -- above_average_networth "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"X_train shape: (2080, 8)\n",
|
|||
|
"y_train shape: (2080, 1)\n",
|
|||
|
"X_test shape: (520, 8)\n",
|
|||
|
"y_test shape: (520, 1)\n",
|
|||
|
"X_train:\n",
|
|||
|
" Rank Name Networth Age Country Source \\\n",
|
|||
|
"2125 2076 Yogesh Kothari 1.4 73 India specialty chemicals \n",
|
|||
|
"1165 1163 Yvonne Bauer 2.7 45 Germany magazines, media \n",
|
|||
|
"397 398 Juergen Blickle 6.4 75 Germany auto parts \n",
|
|||
|
"1432 1397 Alexander Svetakov 2.2 54 Russia real estate \n",
|
|||
|
"1024 1012 Li Min 3.0 56 China semiconductor \n",
|
|||
|
"\n",
|
|||
|
" Industry above_average_networth \n",
|
|||
|
"2125 Manufacturing 0 \n",
|
|||
|
"1165 Media & Entertainment 0 \n",
|
|||
|
"397 Manufacturing 1 \n",
|
|||
|
"1432 Finance & Investments 0 \n",
|
|||
|
"1024 Technology 0 \n",
|
|||
|
"y_train:\n",
|
|||
|
" above_average_networth\n",
|
|||
|
"2125 0\n",
|
|||
|
"1165 0\n",
|
|||
|
"397 1\n",
|
|||
|
"1432 0\n",
|
|||
|
"1024 0\n",
|
|||
|
"X_test:\n",
|
|||
|
" Rank Name Networth Age Country \\\n",
|
|||
|
"2437 2324 Horst Wortmann 1.2 80 Germany \n",
|
|||
|
"2118 2076 Ramesh Juneja 1.4 66 India \n",
|
|||
|
"1327 1292 Teresita Sy-Coson 2.4 71 Philippines \n",
|
|||
|
"2063 1929 Myron Wentz 1.5 82 St. Kitts and Nevis \n",
|
|||
|
"1283 1238 Suh Kyung-bae 2.5 59 South Korea \n",
|
|||
|
"\n",
|
|||
|
" Source Industry above_average_networth \n",
|
|||
|
"2437 footwear Fashion & Retail 0 \n",
|
|||
|
"2118 pharmaceuticals Healthcare 0 \n",
|
|||
|
"1327 diversified diversified 0 \n",
|
|||
|
"2063 health products Fashion & Retail 0 \n",
|
|||
|
"1283 cosmetics Fashion & Retail 0 \n",
|
|||
|
"y_test:\n",
|
|||
|
" above_average_networth\n",
|
|||
|
"2437 0\n",
|
|||
|
"2118 0\n",
|
|||
|
"1327 0\n",
|
|||
|
"2063 0\n",
|
|||
|
"1283 0\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from typing import Tuple\n",
|
|||
|
"from pandas import DataFrame\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"def split_stratified_into_train_val_test(\n",
|
|||
|
" df_input,\n",
|
|||
|
" stratify_colname=\"y\",\n",
|
|||
|
" frac_train=0.6,\n",
|
|||
|
" frac_val=0.15,\n",
|
|||
|
" frac_test=0.25,\n",
|
|||
|
" random_state=None,\n",
|
|||
|
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
" if frac_train + frac_val + frac_test != 1.0:\n",
|
|||
|
" raise ValueError(\n",
|
|||
|
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
|
|||
|
" % (frac_train, frac_val, frac_test)\n",
|
|||
|
" )\n",
|
|||
|
" if stratify_colname not in df_input.columns:\n",
|
|||
|
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
|
|||
|
" X = df_input # Contains all columns.\n",
|
|||
|
" y = df_input[\n",
|
|||
|
" [stratify_colname]\n",
|
|||
|
" ] # Dataframe of just the column on which to stratify.\n",
|
|||
|
" # Split original dataframe into train and temp dataframes.\n",
|
|||
|
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
|
|||
|
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" if frac_val <= 0:\n",
|
|||
|
" assert len(df_input) == len(df_train) + len(df_temp)\n",
|
|||
|
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
|
|||
|
" # Split the temp dataframe into val and test dataframes.\n",
|
|||
|
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
|
|||
|
" df_val, df_test, y_val, y_test = train_test_split(\n",
|
|||
|
" df_temp,\n",
|
|||
|
" y_temp,\n",
|
|||
|
" stratify=y_temp,\n",
|
|||
|
" test_size=relative_frac_test,\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" )\n",
|
|||
|
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
|
|||
|
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
|
|||
|
"\n",
|
|||
|
"# Разделение набора данных на обучающую, валидационную и тестовую выборки (80/0/20)\n",
|
|||
|
"random_state = 42 # Задайте любое целое число для воспроизводимости\n",
|
|||
|
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
|
|||
|
" df, stratify_colname=\"above_average_networth\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Вывод размеров выборок\n",
|
|||
|
"print(\"X_train shape:\", X_train.shape)\n",
|
|||
|
"print(\"y_train shape:\", y_train.shape)\n",
|
|||
|
"print(\"X_test shape:\", X_test.shape)\n",
|
|||
|
"print(\"y_test shape:\", y_test.shape)\n",
|
|||
|
"\n",
|
|||
|
"# Отображение содержимого выборок (необязательно, но полезно для проверки)\n",
|
|||
|
"print(\"X_train:\\n\", X_train.head())\n",
|
|||
|
"print(\"y_train:\\n\", y_train.head())\n",
|
|||
|
"print(\"X_test:\\n\", X_test.head())\n",
|
|||
|
"print(\"y_test:\\n\", y_test.head())\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование конвейера для классификации данных\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
|
|||
|
"\n",
|
|||
|
"features_preprocessing -- трансформер для предобработки признаков\n",
|
|||
|
"\n",
|
|||
|
"features_engineering -- трансформер для конструирования признаков\n",
|
|||
|
"\n",
|
|||
|
"drop_columns -- трансформер для удаления колонок\n",
|
|||
|
"\n",
|
|||
|
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 17,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Определение столбцов для обработки\n",
|
|||
|
"columns_to_drop = [\"Name\", \"Rank \"] # Столбцы, которые можно удалить\n",
|
|||
|
"num_columns = [\"Networth\", \"Age\"] # Числовые столбцы\n",
|
|||
|
"cat_columns = [\"Country\", \"Source\", \"Industry\"] # Категориальные столбцы\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг числовых столбцов\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг категориальных столбцов\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Объединение препроцессинга\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Удаление ненужных столбцов\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание финального пайплайна\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" ]\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Демонстрация работы конвейера__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 18,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Networth Age Country_Argentina Country_Australia \\\n",
|
|||
|
"2125 -0.340947 0.680013 0.0 0.0 \n",
|
|||
|
"1165 -0.211625 -1.475070 0.0 0.0 \n",
|
|||
|
"397 0.156447 0.833948 0.0 0.0 \n",
|
|||
|
"1432 -0.261364 -0.782365 0.0 0.0 \n",
|
|||
|
"1024 -0.181781 -0.628430 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Austria Country_Barbados Country_Belgium Country_Belize \\\n",
|
|||
|
"2125 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1165 0.0 0.0 0.0 0.0 \n",
|
|||
|
"397 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1432 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1024 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Brazil Country_Bulgaria ... Industry_Manufacturing \\\n",
|
|||
|
"2125 0.0 0.0 ... 1.0 \n",
|
|||
|
"1165 0.0 0.0 ... 0.0 \n",
|
|||
|
"397 0.0 0.0 ... 1.0 \n",
|
|||
|
"1432 0.0 0.0 ... 0.0 \n",
|
|||
|
"1024 0.0 0.0 ... 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Media & Entertainment Industry_Metals & Mining \\\n",
|
|||
|
"2125 0.0 0.0 \n",
|
|||
|
"1165 1.0 0.0 \n",
|
|||
|
"397 0.0 0.0 \n",
|
|||
|
"1432 0.0 0.0 \n",
|
|||
|
"1024 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Real Estate Industry_Service Industry_Sports \\\n",
|
|||
|
"2125 0.0 0.0 0.0 \n",
|
|||
|
"1165 0.0 0.0 0.0 \n",
|
|||
|
"397 0.0 0.0 0.0 \n",
|
|||
|
"1432 0.0 0.0 0.0 \n",
|
|||
|
"1024 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Technology Industry_Telecom Industry_diversified \\\n",
|
|||
|
"2125 0.0 0.0 0.0 \n",
|
|||
|
"1165 0.0 0.0 0.0 \n",
|
|||
|
"397 0.0 0.0 0.0 \n",
|
|||
|
"1432 0.0 0.0 0.0 \n",
|
|||
|
"1024 1.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" above_average_networth \n",
|
|||
|
"2125 0 \n",
|
|||
|
"1165 0 \n",
|
|||
|
"397 1 \n",
|
|||
|
"1432 0 \n",
|
|||
|
"1024 0 \n",
|
|||
|
"\n",
|
|||
|
"[5 rows x 859 columns]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Вывод первых строк обработанных данных\n",
|
|||
|
"print(preprocessed_df.head())"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование набора моделей для классификации\n",
|
|||
|
"\n",
|
|||
|
"logistic -- логистическая регрессия\n",
|
|||
|
"\n",
|
|||
|
"ridge -- гребневая регрессия\n",
|
|||
|
"\n",
|
|||
|
"decision_tree -- дерево решений\n",
|
|||
|
"\n",
|
|||
|
"knn -- k-ближайших соседей\n",
|
|||
|
"\n",
|
|||
|
"naive_bayes -- наивный Байесовский классификатор\n",
|
|||
|
"\n",
|
|||
|
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"random_forest -- метод случайного леса (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"mlp -- многослойный персептрон (нейронная сеть)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 20,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn import linear_model, tree, neighbors, naive_bayes, ensemble, neural_network\n",
|
|||
|
"class_models = {\n",
|
|||
|
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
|
|||
|
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
|
|||
|
" \"decision_tree\": {\n",
|
|||
|
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=42)\n",
|
|||
|
" },\n",
|
|||
|
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
|
|||
|
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
|
|||
|
" \"gradient_boosting\": {\n",
|
|||
|
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
|
|||
|
" },\n",
|
|||
|
" \"random_forest\": {\n",
|
|||
|
" \"model\": ensemble.RandomForestClassifier(\n",
|
|||
|
" max_depth=11, class_weight=\"balanced\", random_state=42\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"mlp\": {\n",
|
|||
|
" \"model\": neural_network.MLPClassifier(\n",
|
|||
|
" hidden_layer_sizes=(7,),\n",
|
|||
|
" max_iter=500,\n",
|
|||
|
" early_stopping=True,\n",
|
|||
|
" random_state=42,\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
"}"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Обучение моделей и оценка их качества"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 21,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: ridge\n",
|
|||
|
"Model: decision_tree\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: knn\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: naive_bayes\n",
|
|||
|
"Model: gradient_boosting\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: random_forest\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
"\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
|||
|
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
|||
|
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" # Оценка метрик\n",
|
|||
|
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
|
|||
|
" y_test, y_test_probs\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Сводная таблица оценок качества для использованных моделей классификации"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 22,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4UAAAQ9CAYAAADu7ug2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVwU5R8H8M9yLcgNcoiioihHHgRakQfeaJ5peWGiqZV5l+ZRHqhEaeZ95oEUpqVl3qYmpGbemimSeEEKaqIoKtfu8/uDmJ8bsoAu7O7wef9e8/q58zw78501+PjMPjOjEEIIEBERERERUYVkou8CiIiIiIiISH84KCQiIiIiIqrAOCgkIiIiIiKqwDgoJCIiIiIiqsA4KCQiIiIiIqrAOCgkIiIiIiKqwDgoJCIiIiIiqsA4KCQiIiIiIqrAOCgkIiIiIiKqwDgoJNKx6OhoKBQKXL16tUy2f/XqVSgUCkRHR+tke3FxcVAoFIiLi9PJ9oiIiORi2rRpUCgUJeqrUCgwbdq0si2IqIxwUEhUQSxZskRnA0kiIiIikg8zfRdARKVTo0YNPH78GObm5qV635IlS1C5cmUMGDBAY33z5s3x+PFjWFhY6LBKIiIi4/fJJ59gwoQJ+i6DqMxxUEhkZBQKBSwtLXW2PRMTE51uj4iISA4ePnwIa2trmJnxn8skf5w+SlQOlixZghdeeAFKpRIeHh4YNmwY7t27V6jf4sWLUatWLVhZWeGll17CgQMH0KJFC7Ro0ULq87RrCtPS0jBw4EBUq1YNSqUSVapUQdeuXaXrGmvWrIlz584hPj4eCoUCCoVC2mZR1xQeOXIEr732GhwdHWFtbY0GDRpg/vz5uv1giIiIDEDBtYPnz59H37594ejoiKZNmz71msLs7GyMGTMGLi4usLW1RZcuXfD3338/dbtxcXFo1KgRLC0tUbt2bSxfvrzI6xS/+eYbBAUFwcrKCk5OTujduzdSUlLK5HiJ/ounPojK2LRp0xAREYE2bdpg6NChSExMxNKlS3Hs2DEcOnRImga6dOlSDB8+HM2aNcOYMWNw9epVdOvWDY6OjqhWrZrWffTo0QPnzp3DiBEjULNmTdy6dQt79uxBcnIyatasiXnz5mHEiBGwsbHBxx9/DABwc3Mrcnt79uxBp06dUKVKFYwaNQru7u5ISEjAtm3bMGrUKN19OERERAbkzTffRJ06dfDpp59CCIFbt24V6jN48GB888036Nu3L1599VX88ssv6NixY6F+p06dQvv27VGlShVERERApVJh+vTpcHFxKdQ3MjISkydPRs+ePTF48GDcvn0bCxcuRPPmzXHq1Ck4ODiUxeES/Z8gIp1as2aNACCuXLkibt26JSwsLES7du2ESqWS+ixatEgAEKtXrxZCCJGdnS2cnZ1F48aNRW5urtQvOjpaABAhISHSuitXrggAYs2aNUIIIe7evSsAiNmzZ2ut64UXXtDYToH9+/cLAGL//v1CCCHy8vKEl5eXqFGjhrh7965GX7VaXfIPgoiIyEhMnTpVABB9+vR56voCp0+fFgDE+++/r9Gvb9++AoCYOnWqtK5z586iUqVK4vr169K6ixcvCjMzM41tXr16VZiamorIyEiNbZ49e1aYmZkVWk9UFjh9lKgM7d27Fzk5ORg9ejRMTP7/4zZkyBDY2dlh+/btAIDjx4/jzp07GDJkiMa1C2FhYXB0dNS6DysrK1hYWCAuLg5379597ppPnTqFK1euYPTo0YXOTJb0ttxERETG6L333tPavmPHDgDAyJEjNdaPHj1a47VKpcLevXvRrVs3eHh4SOu9vb3RoUMHjb4//PAD1Go1evbsiX/++Uda3N3dUadOHezfv/85joioZDh9lKgMXbt2DQDg4+Ojsd7CwgK1atWS2gv+39vbW6OfmZkZatasqXUfSqUSn3/+OT788EO4ubnhlVdeQadOndC/f3+4u7uXuuZLly4BAOrVq1fq9xIRERkzLy8vre3Xrl2DiYkJateurbH+vzl/69YtPH78uFCuA4Wz/uLFixBCoE6dOk/dZ2nvNk70LDgoJJKB0aNHo3Pnzti8eTN2796NyZMnIyoqCr/88gtefPFFfZdHRERkFKysrMp9n2q1GgqFAjt37oSpqWmhdhsbm3KviSoeTh8lKkM1atQAACQmJmqsz8nJwZUrV6T2gv9PSkrS6JeXlyfdQbQ4tWvXxocffoiff/4Zf/75J3JycjBnzhypvaRTPwvOfv75558l6k9ERFRR1KhRA2q1WppVU+C/Oe/q6gpLS8tCuQ4UzvratWtDCAEvLy+0adOm0PLKK6/o/kCI/oODQqIy1KZNG1hYWGDBggUQQkjrV61ahYyMDOluZY0aNYKzszO++uor5OXlSf1iY2OLvU7w0aNHyMrK0lhXu3Zt2NraIjs7W1pnbW391Mdg/FdgYCC8vLwwb968Qv2fPAYiIqKKpuB6wAULFmisnzdvnsZrU1NTtGnTBps3b8aNGzek9UlJSdi5c6dG3+7du8PU1BQRERGFclYIgTt37ujwCIiejtNHicqQi4sLJk6ciIiICLRv3x5dunRBYmIilixZgsaNG6Nfv34A8q8xnDZtGkaMGIFWrVqhZ8+euHr1KqKjo1G7dm2t3/L99ddfaN26NXr27Al/f3+YmZnhxx9/xM2bN9G7d2+pX1BQEJYuXYqZM2fC29sbrq6uaNWqVaHtmZiYYOnSpejcuTMCAgIwcOBAVKlSBRcuXMC5c+ewe/du3X9QRERERiAgIAB9+vTBkiVLkJGRgVdffRX79u176jeC06ZNw88//4wmTZpg6NChUKlUWLRoEerVq4fTp09L/WrXro2ZM2di4sSJ0uOobG1tceXKFfz444945513MHbs2HI8SqqIOCgkKmPTpk2Di4sLFi1ahDFjxsDJyQnvvPMOPv30U42Lx4cPHw4hBObMmYOxY8eiYcOG2LJlC0aOHAlLS8sit+/p6Yk+ffpg3759+Prrr2FmZgZfX19899136NGjh9RvypQpuHbtGmbNmoUHDx4gJCTkqYNCAAgNDcX+/fsRERGBOXPmQK1Wo3bt2hgyZIjuPhgiIiIjtHr1ari4uCA2NhabN29Gq1atsH37dnh6emr0CwoKws6dOzF27FhMnjwZnp6emD59OhISEnDhwgWNvhMmTEDdunUxd+5cREREAMjP93bt2qFLly7ldmxUcSkE54MRGSy1Wg0XFxd0794dX331lb7LISIioufUrVs3nDt3DhcvXtR3KUQSXlNIZCCysrIKXUsQExOD9PR0tGjRQj9FERER0TN7/PixxuuLFy9ix44dzHUyOPymkMhAxMXFYcyYMXjzzTfh7OyMkydPYtWqVfDz88OJEydgYWGh7xKJiIioFKpUqYIBAwZIzyZeunQpsrOzcerUqSKfS0ikD7ymkMhA1KxZE56enliwYAHS09Ph5OSE/v3747PPPuOAkIiIyAi1b98e3377LdLS0qBUKhEcHIxPP/2UA0IyOPymkIiIiIiIqALjNYVERBXQZ599BoVCgdGjR0vrsrKyMGzYMDg7O8PGxgY9evTAzZs3Nd6XnJyMjh07olKlSnB1dcW4ceM0nq1JREREz0af2cxBIRFRBXPs2DEsX74cDRo00Fg/ZswYbN26Fd9//z3i4+Nx48YNdO/eXWpXqVTo2LEjcnJy8Ntvv2Ht2rWIjo7GlClTyvsQiIiIZEXf2czpo/Rc1Go1bty4AVtbW60PWCeSIyEEHjx4AA8PD5iY6PYcW1ZWFnJycortZ2FhofU5lv+VmZmJwMBALFmyBDNnzkRAQADmzZuHjIwMuLi4YN26dXjjjTcAABcuXICfnx8OHz6MV155BTt37kSnTp1w48YNuLm5AQCWLVuG8ePH4/bt27z2lchAMJupImM2P1s280Yz9Fxu3LhR6GGtRBVNSkoKqlWrprPtZWVlwauGDdJuqYrt6+7ujjNnzmiEj1KphFKpfGr/YcOGoWPHjmjTpg1mzpwprT9x4gRyc3PRpk0baZ2vry+qV68uBc/hw4dRv359KXQAIDQ0FEOHDsW
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x1000 with 16 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"for index, key in enumerate(class_models.keys()):\n",
|
|||
|
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"Below Average\", \"Above Average\"] # Измените метки на нужные\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(key)\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"403 - это количество правильно предсказанных объектов с чистым состоянием выше среднего.\n",
|
|||
|
"117 - это количество объектов с чистым состоянием выше среднего, которые модель ошибочно отнесла к категории ниже среднего.\n",
|
|||
|
"1. Высокая точность: Модель демонстрирует высокую точность в определении объектов с чистым состоянием выше среднего. Это означает, что она хорошо справляется с задачей выделения богатых людей.\n",
|
|||
|
"2. Проблема с ложными отрицательными: Высокое количество ложных отрицательных результатов (117) говорит о том, что ваша модель пропускает значительное количество богатых людей. Она не всегда распознает их как \"выше среднего\".\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Точность, полнота, верность (аккуратность), F-мера"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 23,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_36985_row0_col0, #T_36985_row1_col0, #T_36985_row2_col0, #T_36985_row3_col0, #T_36985_row4_col0, #T_36985_row5_col0, #T_36985_row6_col0, #T_36985_row7_col0 {\n",
|
|||
|
" background-color: #440154;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row0_col1, #T_36985_row0_col2, #T_36985_row0_col3, #T_36985_row1_col1, #T_36985_row1_col2, #T_36985_row1_col3, #T_36985_row2_col1, #T_36985_row2_col2, #T_36985_row2_col3, #T_36985_row3_col1, #T_36985_row3_col2, #T_36985_row3_col3, #T_36985_row4_col1, #T_36985_row4_col2, #T_36985_row4_col3, #T_36985_row5_col1, #T_36985_row6_col2, #T_36985_row7_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row0_col4, #T_36985_row0_col5, #T_36985_row0_col6, #T_36985_row0_col7, #T_36985_row1_col4, #T_36985_row1_col5, #T_36985_row1_col6, #T_36985_row1_col7, #T_36985_row2_col4, #T_36985_row2_col5, #T_36985_row2_col6, #T_36985_row2_col7, #T_36985_row3_col4, #T_36985_row3_col5, #T_36985_row3_col6, #T_36985_row3_col7, #T_36985_row4_col4, #T_36985_row4_col5, #T_36985_row4_col6, #T_36985_row4_col7, #T_36985_row6_col4, #T_36985_row6_col6 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row5_col2 {\n",
|
|||
|
" background-color: #a0da39;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row5_col3 {\n",
|
|||
|
" background-color: #8ed645;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row5_col4, #T_36985_row5_col6 {\n",
|
|||
|
" background-color: #d7566c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row5_col5 {\n",
|
|||
|
" background-color: #d14e72;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row5_col7 {\n",
|
|||
|
" background-color: #d24f71;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row6_col1, #T_36985_row7_col2, #T_36985_row7_col3 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row6_col3 {\n",
|
|||
|
" background-color: #9bd93c;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row6_col5 {\n",
|
|||
|
" background-color: #a11b9b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row6_col7 {\n",
|
|||
|
" background-color: #aa2395;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_36985_row7_col4, #T_36985_row7_col5, #T_36985_row7_col6, #T_36985_row7_col7 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_36985\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_36985_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_36985_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_36985_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_36985_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_36985_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_36985_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_36985_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_36985_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_36985_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
|
|||
|
" <td id=\"T_36985_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_36985_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
|
|||
|
" <td id=\"T_36985_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_36985_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_36985_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_36985_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_36985_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_36985_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_36985_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_36985_level0_row5\" class=\"row_heading level0 row5\" >mlp</th>\n",
|
|||
|
" <td id=\"T_36985_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row5_col2\" class=\"data row5 col2\" >0.995726</td>\n",
|
|||
|
" <td id=\"T_36985_row5_col3\" class=\"data row5 col3\" >0.982906</td>\n",
|
|||
|
" <td id=\"T_36985_row5_col4\" class=\"data row5 col4\" >0.999038</td>\n",
|
|||
|
" <td id=\"T_36985_row5_col5\" class=\"data row5 col5\" >0.996154</td>\n",
|
|||
|
" <td id=\"T_36985_row5_col6\" class=\"data row5 col6\" >0.997859</td>\n",
|
|||
|
" <td id=\"T_36985_row5_col7\" class=\"data row5 col7\" >0.991379</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_36985_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_36985_row6_col0\" class=\"data row6 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row6_col1\" class=\"data row6 col1\" >0.920635</td>\n",
|
|||
|
" <td id=\"T_36985_row6_col2\" class=\"data row6 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row6_col3\" class=\"data row6 col3\" >0.991453</td>\n",
|
|||
|
" <td id=\"T_36985_row6_col4\" class=\"data row6 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row6_col5\" class=\"data row6 col5\" >0.978846</td>\n",
|
|||
|
" <td id=\"T_36985_row6_col6\" class=\"data row6 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row6_col7\" class=\"data row6 col7\" >0.954733</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_36985_level0_row7\" class=\"row_heading level0 row7\" >knn</th>\n",
|
|||
|
" <td id=\"T_36985_row7_col0\" class=\"data row7 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row7_col1\" class=\"data row7 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_36985_row7_col2\" class=\"data row7 col2\" >0.848291</td>\n",
|
|||
|
" <td id=\"T_36985_row7_col3\" class=\"data row7 col3\" >0.811966</td>\n",
|
|||
|
" <td id=\"T_36985_row7_col4\" class=\"data row7 col4\" >0.965865</td>\n",
|
|||
|
" <td id=\"T_36985_row7_col5\" class=\"data row7 col5\" >0.957692</td>\n",
|
|||
|
" <td id=\"T_36985_row7_col6\" class=\"data row7 col6\" >0.917919</td>\n",
|
|||
|
" <td id=\"T_36985_row7_col7\" class=\"data row7 col7\" >0.896226</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x29dd91458e0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 23,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(\n",
|
|||
|
" by=\"Accuracy_test\", ascending=False\n",
|
|||
|
").style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Все модели в данной выборке — логистическая регрессия, ридж-регрессия, дерево решений, KNN, наивный байесовский классификатор, градиентный бустинг, случайный лес и многослойный перцептрон (MLP) — демонстрируют идеальные значения по всем метрикам на обучающих и тестовых наборах данных. Это достигается, поскольку все модели показали значения, равные 1.0 для Precision, Recall, Accuracy и F1-меры, что указывает на то, что модель безошибочно классифицирует все примеры.\n",
|
|||
|
"\n",
|
|||
|
"Модель MLP, хотя и имеет немного более низкие значения Recall (0.994) и F1-на тестовом наборе (0.997) по сравнению с другими, по-прежнему остается высокоэффективной. Тем не менее, она не снижает показатели классификации до такого уровня, что может вызвать обеспокоенность, и остается на уровне, близком к идеальному."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 24,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_b46dc_row0_col0, #T_b46dc_row0_col1, #T_b46dc_row1_col0, #T_b46dc_row1_col1, #T_b46dc_row2_col0, #T_b46dc_row2_col1, #T_b46dc_row3_col0, #T_b46dc_row3_col1, #T_b46dc_row5_col0, #T_b46dc_row5_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row0_col2, #T_b46dc_row0_col3, #T_b46dc_row0_col4, #T_b46dc_row1_col2, #T_b46dc_row1_col3, #T_b46dc_row1_col4, #T_b46dc_row2_col2, #T_b46dc_row2_col3, #T_b46dc_row2_col4, #T_b46dc_row3_col2, #T_b46dc_row3_col3, #T_b46dc_row3_col4, #T_b46dc_row4_col2, #T_b46dc_row5_col2, #T_b46dc_row5_col3, #T_b46dc_row5_col4 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row4_col0 {\n",
|
|||
|
" background-color: #8ed645;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row4_col1 {\n",
|
|||
|
" background-color: #90d743;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row4_col3 {\n",
|
|||
|
" background-color: #d24f71;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row4_col4 {\n",
|
|||
|
" background-color: #d14e72;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row6_col0, #T_b46dc_row6_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row6_col2 {\n",
|
|||
|
" background-color: #cd4a76;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row6_col3, #T_b46dc_row6_col4, #T_b46dc_row7_col2 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row7_col0 {\n",
|
|||
|
" background-color: #2fb47c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row7_col1 {\n",
|
|||
|
" background-color: #3bbb75;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row7_col3 {\n",
|
|||
|
" background-color: #a72197;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_b46dc_row7_col4 {\n",
|
|||
|
" background-color: #a51f99;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_b46dc\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_b46dc_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_b46dc_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_b46dc_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_b46dc_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_b46dc_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_b46dc_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
|
|||
|
" <td id=\"T_b46dc_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_b46dc_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
|
|||
|
" <td id=\"T_b46dc_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_b46dc_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_b46dc_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_b46dc_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_b46dc_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_b46dc_level0_row4\" class=\"row_heading level0 row4\" >mlp</th>\n",
|
|||
|
" <td id=\"T_b46dc_row4_col0\" class=\"data row4 col0\" >0.996154</td>\n",
|
|||
|
" <td id=\"T_b46dc_row4_col1\" class=\"data row4 col1\" >0.991379</td>\n",
|
|||
|
" <td id=\"T_b46dc_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row4_col3\" class=\"data row4 col3\" >0.988904</td>\n",
|
|||
|
" <td id=\"T_b46dc_row4_col4\" class=\"data row4 col4\" >0.988965</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_b46dc_level0_row5\" class=\"row_heading level0 row5\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_b46dc_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_b46dc_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_b46dc_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
|
|||
|
" <td id=\"T_b46dc_row6_col0\" class=\"data row6 col0\" >0.957692</td>\n",
|
|||
|
" <td id=\"T_b46dc_row6_col1\" class=\"data row6 col1\" >0.896226</td>\n",
|
|||
|
" <td id=\"T_b46dc_row6_col2\" class=\"data row6 col2\" >0.997858</td>\n",
|
|||
|
" <td id=\"T_b46dc_row6_col3\" class=\"data row6 col3\" >0.870015</td>\n",
|
|||
|
" <td id=\"T_b46dc_row6_col4\" class=\"data row6 col4\" >0.877459</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_b46dc_level0_row7\" class=\"row_heading level0 row7\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_b46dc_row7_col0\" class=\"data row7 col0\" >0.978846</td>\n",
|
|||
|
" <td id=\"T_b46dc_row7_col1\" class=\"data row7 col1\" >0.954733</td>\n",
|
|||
|
" <td id=\"T_b46dc_row7_col2\" class=\"data row7 col2\" >0.983320</td>\n",
|
|||
|
" <td id=\"T_b46dc_row7_col3\" class=\"data row7 col3\" >0.940955</td>\n",
|
|||
|
" <td id=\"T_b46dc_row7_col4\" class=\"data row7 col4\" >0.942055</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x29ddbb10da0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 24,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"\n",
|
|||
|
"# Сортировка по ROC_AUC_test в порядке убывания\n",
|
|||
|
"class_metrics = class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)\n",
|
|||
|
"\n",
|
|||
|
"# Применение стилей\n",
|
|||
|
"class_metrics.style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\", # Цветовая палитра для ROC_AUC_test, MCC_test, Cohen_kappa_test\n",
|
|||
|
" low=0.3, # Минимальное значение для цветового градиента\n",
|
|||
|
" high=1, # Максимальное значение для цветового градиента\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\", # Цветовая палитра для Accuracy_test, F1_test\n",
|
|||
|
" low=1, # Минимальное значение для цветового градиента\n",
|
|||
|
" high=0.3, # Максимальное значение для цветового градиента\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Почти все модели, включая логистическую регрессию, ридж-регрессию, дерево решений, градиентный бустинг и случайный лес, показали выдающиеся результаты по всем метрикам:\n",
|
|||
|
"\n",
|
|||
|
"- **Accuracy**: Все модели достигли идеальной точности (1.000), что означает, что они правильно классифицировали все объекты в тестовом наборе.\n",
|
|||
|
"- **F1**: Аналогично, все модели показали идеальное значение F1-меры (1.000), что говорит о балансе между точностью (precision) и полнотой (recall).\n",
|
|||
|
"- **ROC AUC**: Все модели достигли максимального значения ROC AUC (1.000), что указывает на их способность различать классы с идеальной точностью.\n",
|
|||
|
"- **Cohen's Kappa**: Идеальное значение Cohen's Kappa (1.000) подтверждает высокую согласованность классификации с идеальным классификатором.\n",
|
|||
|
"- **MCC**: Идеальное значение MCC (1.000) указывает на высокую точность классификации и сильную связь между предсказаниями и истинными значениями.\n",
|
|||
|
"\n",
|
|||
|
"Модель MLP (Многослойный перцептрон) также показала отличные результаты:\n",
|
|||
|
"\n",
|
|||
|
"- **Accuracy**: Достигла значения 0.996, что немного ниже идеального, но все еще очень высокий результат.\n",
|
|||
|
"- **F1**: Значение F1-меры равно 0.991, что также указывает на высокую эффективность модели.\n",
|
|||
|
"- **ROC AUC**: MLP достигает идеального значения ROC AUC (1.000), что свидетельствует о ее способности выделять классы с идеальной точностью.\n",
|
|||
|
"- **Cohen's Kappa**: Высокое значение Cohen's Kappa (0.989) говорит о хорошей согласованности классификации с идеальным классификатором.\n",
|
|||
|
"- **MCC**: Высокое значение MCC (0.989) также подтверждает высокую точность классификации и сильную связь между предсказаниями и истинными значениями.\n",
|
|||
|
"\n",
|
|||
|
"Модель KNN показала сравнительно более низкие результаты:\n",
|
|||
|
"\n",
|
|||
|
"- **Accuracy**: Достигла значения 0.958, что ниже идеального, но все еще является приемлемым результатом.\n",
|
|||
|
"- **F1**: Значение F1-меры равно 0.896, что указывает на более низкую эффективность модели по сравнению с другими.\n",
|
|||
|
"- **ROC AUC**: KNN достигает значения ROC AUC 0.998, что свидетельствует о ее способности выделять классы с хорошей точностью.\n",
|
|||
|
"- **Cohen's Kappa**: Значение Cohen's Kappa (0.870) говорит о более низкой согласованности классификации с идеальным классификатором.\n",
|
|||
|
"- **MCC**: Значение MCC (0.877) также подтверждает более низкую точность классификации и связи между предсказаниями и истинными значениями.\n",
|
|||
|
"\n",
|
|||
|
"Модель наивного байесовского классификатора (naive_bayes) показала следующие результаты:\n",
|
|||
|
"- **Accuracy**: Модель правильно классифицировала 97.88% объектов в тестовом наборе. Это довольно хороший результат, но не идеальный.\n",
|
|||
|
"- **F1-мера**: Значение F1-меры 0.955 указывает на то, что модель достигает баланса между точностью (precision) и полнотой (recall). Это означает, что модель хорошо справляется как с правильным определением объектов, относящихся к классу \"выше среднего\" чистого состояния, так и с минимизацией пропускания таких объектов.\n",
|
|||
|
"- **ROC AUC**: Модель достигла значения ROC AUC 0.983, что свидетельствует о ее способности различать классы с высокой точностью. \n",
|
|||
|
"- **Cohen's Kappa**: Значение 0.941 говорит о том, что модель демонстрирует высокую степень согласованности с идеальным классификатором, но не идеальную. \n",
|
|||
|
"- **MCC**: MCC 0.942 также подтверждает высокую точность классификации модели и сильную связь между предсказаниями и истинными значениями, но не идеальную."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 25,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'logistic'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
|
|||
|
"\n",
|
|||
|
"display(best_model)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Вывод данных с ошибкой предсказания для оценки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 26,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'Error items count: 0'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Rank</th>\n",
|
|||
|
" <th>Predicted</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Networth</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Country</th>\n",
|
|||
|
" <th>Source</th>\n",
|
|||
|
" <th>Industry</th>\n",
|
|||
|
" <th>above_average_networth</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"Empty DataFrame\n",
|
|||
|
"Columns: [Rank , Predicted, Name, Networth, Age, Country, Source, Industry, above_average_networth]\n",
|
|||
|
"Index: []"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Преобразование тестовых данных\n",
|
|||
|
"preprocessing_result = pipeline_end.transform(X_test)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Получение предсказаний лучшей модели\n",
|
|||
|
"y_pred = class_models[best_model][\"preds\"]\n",
|
|||
|
"\n",
|
|||
|
"# Нахождение индексов ошибок\n",
|
|||
|
"error_index = y_test[y_test[\"above_average_networth\"] != y_pred].index.tolist() # Изменено на \"above_average_networth\"\n",
|
|||
|
"display(f\"Error items count: {len(error_index)}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создание DataFrame с ошибочными объектами\n",
|
|||
|
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
|
|||
|
"error_df = X_test.loc[error_index].copy()\n",
|
|||
|
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
|
|||
|
"error_df = error_df.sort_index() # Сортировка по индексу\n",
|
|||
|
"\n",
|
|||
|
"# Вывод DataFrame с ошибочными объектами\n",
|
|||
|
"display(error_df)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Пример использования обученной модели (конвейера) для предсказания"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 27,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Rank</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Networth</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Country</th>\n",
|
|||
|
" <th>Source</th>\n",
|
|||
|
" <th>Industry</th>\n",
|
|||
|
" <th>above_average_networth</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1465</th>\n",
|
|||
|
" <td>1445</td>\n",
|
|||
|
" <td>Gordon Getty</td>\n",
|
|||
|
" <td>2.1</td>\n",
|
|||
|
" <td>88</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>Getty Oil</td>\n",
|
|||
|
" <td>Energy</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Rank Name Networth Age Country Source Industry \\\n",
|
|||
|
"1465 1445 Gordon Getty 2.1 88 United States Getty Oil Energy \n",
|
|||
|
"\n",
|
|||
|
" above_average_networth \n",
|
|||
|
"1465 0 "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Networth</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Country_Argentina</th>\n",
|
|||
|
" <th>Country_Australia</th>\n",
|
|||
|
" <th>Country_Austria</th>\n",
|
|||
|
" <th>Country_Barbados</th>\n",
|
|||
|
" <th>Country_Belgium</th>\n",
|
|||
|
" <th>Country_Belize</th>\n",
|
|||
|
" <th>Country_Brazil</th>\n",
|
|||
|
" <th>Country_Bulgaria</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>Industry_Manufacturing</th>\n",
|
|||
|
" <th>Industry_Media & Entertainment</th>\n",
|
|||
|
" <th>Industry_Metals & Mining</th>\n",
|
|||
|
" <th>Industry_Real Estate</th>\n",
|
|||
|
" <th>Industry_Service</th>\n",
|
|||
|
" <th>Industry_Sports</th>\n",
|
|||
|
" <th>Industry_Technology</th>\n",
|
|||
|
" <th>Industry_Telecom</th>\n",
|
|||
|
" <th>Industry_diversified</th>\n",
|
|||
|
" <th>above_average_networth</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1465</th>\n",
|
|||
|
" <td>-0.271312</td>\n",
|
|||
|
" <td>1.834522</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>1 rows × 859 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Networth Age Country_Argentina Country_Australia \\\n",
|
|||
|
"1465 -0.271312 1.834522 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Austria Country_Barbados Country_Belgium Country_Belize \\\n",
|
|||
|
"1465 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Brazil Country_Bulgaria ... Industry_Manufacturing \\\n",
|
|||
|
"1465 0.0 0.0 ... 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Media & Entertainment Industry_Metals & Mining \\\n",
|
|||
|
"1465 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Real Estate Industry_Service Industry_Sports \\\n",
|
|||
|
"1465 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Technology Industry_Telecom Industry_diversified \\\n",
|
|||
|
"1465 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" above_average_networth \n",
|
|||
|
"1465 0.0 \n",
|
|||
|
"\n",
|
|||
|
"[1 rows x 859 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"predicted: 0 (proba: [0.99415059 0.00584941])\n",
|
|||
|
"real: 0\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Выбираем лучшую модель\n",
|
|||
|
"model = class_models[best_model][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"# Выбираем позиционный индекс объекта для анализа\n",
|
|||
|
"example_index = 13\n",
|
|||
|
"\n",
|
|||
|
"# Получаем исходные данные для объекта\n",
|
|||
|
"test = pd.DataFrame(X_test.iloc[example_index, :]).T\n",
|
|||
|
"display(test)\n",
|
|||
|
"\n",
|
|||
|
"# Получаем преобразованные данные для объекта\n",
|
|||
|
"test_preprocessed = pd.DataFrame(preprocessed_df.iloc[example_index, :]).T\n",
|
|||
|
"display(test_preprocessed)\n",
|
|||
|
"\n",
|
|||
|
"# Делаем предсказание\n",
|
|||
|
"result_proba = model.predict_proba(test)[0]\n",
|
|||
|
"result = model.predict(test)[0]\n",
|
|||
|
"\n",
|
|||
|
"# Получаем реальное значение\n",
|
|||
|
"real = int(y_test.iloc[example_index].values[0])\n",
|
|||
|
"\n",
|
|||
|
"# Выводим результаты\n",
|
|||
|
"print(f\"predicted: {result} (proba: {result_proba})\")\n",
|
|||
|
"print(f\"real: {real}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Подбор гиперпараметров методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 28,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
|||
|
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'model__criterion': 'gini',\n",
|
|||
|
" 'model__max_depth': 5,\n",
|
|||
|
" 'model__max_features': 'sqrt',\n",
|
|||
|
" 'model__n_estimators': 50}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 28,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import GridSearchCV\n",
|
|||
|
"\n",
|
|||
|
"optimized_model_type = \"random_forest\"\n",
|
|||
|
"\n",
|
|||
|
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" \"model__n_estimators\": [10, 50, 100],\n",
|
|||
|
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
|
|||
|
" \"model__max_depth\": [5, 7, 10],\n",
|
|||
|
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"gs_optomizer = GridSearchCV(\n",
|
|||
|
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
|
|||
|
")\n",
|
|||
|
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"gs_optomizer.best_params_"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Обучение модели с новыми гиперпараметрами__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 29,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# Определяем числовые признаки\n",
|
|||
|
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
|
|||
|
"\n",
|
|||
|
"# Установка random_state\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"\n",
|
|||
|
"# Определение трансформера\n",
|
|||
|
"pipeline_end = ColumnTransformer([\n",
|
|||
|
" ('numeric', StandardScaler(), numeric_features),\n",
|
|||
|
" # Добавьте другие трансформеры, если требуется\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Объявление модели\n",
|
|||
|
"optimized_model = RandomForestClassifier(\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" criterion=\"gini\",\n",
|
|||
|
" max_depth=5,\n",
|
|||
|
" max_features=\"sqrt\",\n",
|
|||
|
" n_estimators=10,\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание пайплайна с корректными шагами\n",
|
|||
|
"result = {}\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели\n",
|
|||
|
"result[\"pipeline\"] = Pipeline([\n",
|
|||
|
" (\"pipeline\", pipeline_end),\n",
|
|||
|
" (\"model\", optimized_model)\n",
|
|||
|
"]).fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
"# Прогнозирование и расчет метрик\n",
|
|||
|
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
|
|||
|
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
|
|||
|
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
"# Метрики для оценки модели\n",
|
|||
|
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
|
|||
|
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование данных для оценки старой и новой версии модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 30,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=class_models[optimized_model_type]\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=result\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
|
|||
|
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Оценка параметров старой и новой модели"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 31,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_bd760_row0_col0, #T_bd760_row0_col1, #T_bd760_row0_col2, #T_bd760_row0_col3, #T_bd760_row1_col0, #T_bd760_row1_col1, #T_bd760_row1_col2, #T_bd760_row1_col3 {\n",
|
|||
|
" background-color: #440154;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_bd760_row0_col4, #T_bd760_row0_col5, #T_bd760_row0_col6, #T_bd760_row0_col7, #T_bd760_row1_col4, #T_bd760_row1_col5, #T_bd760_row1_col6, #T_bd760_row1_col7 {\n",
|
|||
|
" background-color: #0d0887;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_bd760\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_bd760_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_bd760_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_bd760_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_bd760_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_bd760_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_bd760_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_bd760_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_bd760_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" <th class=\"blank col5\" > </th>\n",
|
|||
|
" <th class=\"blank col6\" > </th>\n",
|
|||
|
" <th class=\"blank col7\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_bd760_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_bd760_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_bd760_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_bd760_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_bd760_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x29dc51d7f80>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 31,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"].style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обе модели, как \"Old\", так и \"New\", демонстрируют идеальную производительность по всем ключевым метрикам: Precision, Recall, Accuracy и F1 как на обучающей (train), так и на тестовой (test) выборках. Все значения равны 1.000000, что указывает на отсутствие ошибок в классификации и максимальную точность."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 32,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_84894_row0_col0, #T_84894_row0_col1, #T_84894_row1_col0, #T_84894_row1_col1 {\n",
|
|||
|
" background-color: #440154;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_84894_row0_col2, #T_84894_row0_col3, #T_84894_row0_col4, #T_84894_row1_col2, #T_84894_row1_col3, #T_84894_row1_col4 {\n",
|
|||
|
" background-color: #0d0887;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_84894\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_84894_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_84894_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_84894_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_84894_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_84894_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_84894_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_84894_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_84894_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_84894_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_84894_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_84894_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_84894_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_84894_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_84894_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_84894_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_84894_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_84894_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x29dc51d7d70>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 32,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"].style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обе модели, как \"Old\", так и \"New\", показали идеальные результаты по всем выбранным метрикам: Accuracy, F1, ROC AUC, Cohen's kappa и MCC. Все метрики имеют значение 1.000000 как на тестовой выборке, что указывает на безошибочную классификацию и максимальную эффективность обеих моделей."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 33,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6cAAAGvCAYAAACn0KM0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABrzklEQVR4nO3de3yO9ePH8fe982y7N3PYjDkzk1MkrRxzGCGikAkl/fJFpUgpZ1qR6huFoqTmS0UHFCGH0b4ROYSElBVDYTOHne7r94fvrrpz2uze7vue1/PxuB65r+tzf67PdVt7+1yfz/W5LYZhGAIAAAAAwIk8nN0AAAAAAADonAIAAAAAnI7OKQAAAADA6eicAgAAAACcjs4pAAAAAMDp6JwCAAAAAJyOzikAAAAAwOnonAIAAAAAnI7OKQAAAADA6eicAgAAAACcjs4pAADX6cUXX5TFYtETTzxh7rtw4YIGDx6sUqVKKTAwUN27d9exY8fs3nf48GF17NhRJUqUUNmyZTVixAhlZ2cXcesBALDn7FzzKugFAABQ1C5cuKDMzEyH1efj4yM/P798vWfLli2aPXu26tWrZ7d/2LBhWr58uT766CMFBwdryJAh6tatmzZt2iRJysnJUceOHRUeHq5vvvlGR48eVd++feXt7a0XXnjBYdcEAHAf5Nr/GAAAuJHz588b4WU9DUkO28LDw43z58/nuQ1nzpwxatSoYaxatcpo0aKF8fjjjxuGYRinT582vL29jY8++sgsu3fvXkOSkZSUZBiGYXzxxReGh4eHkZKSYpaZOXOmYbVajYyMDMd8SAAAt0Gu/YWRUwCAW8nMzFTK8Rwd2lpJ1qCCP52SdsamKo1+1R9//CGr1Wru9/X1la+v72XfM3jwYHXs2FFt2rTRpEmTzP1bt25VVlaW2rRpY+6rVauWKlasqKSkJN12221KSkpS3bp1FRYWZpaJjY3VoEGDtHv3bt18880FviYAgPsg1/5C5xQA4JasQR4OCfFckZGRdq/Hjh2rcePGXVJu4cKF2rZtm7Zs2XLJsZSUFPn4+CgkJMRuf1hYmFJSUswyfw/w3OO5xwAANyZyjc4pAMBN5Rg25RiOqUeSkpOTL7nD/E/Jycl6/PHHtWrVqnw/ywMAwNWQa6zWCwBwUzYZDtskyWq12m2XC/GtW7fq+PHjatiwoby8vOTl5aX169fr9ddfl5eXl8LCwpSZmanTp0/bve/YsWMKDw+XJIWHh1+yymHu69wyAIAbD7lG5xQAgDxr3bq1du3ape3bt5vbLbfcori4OPPP3t7eWrNmjfmeffv26fDhw4qJiZEkxcTEaNeuXTp+/LhZZtWqVbJarapdu3aRXxMA4MblarnGtF4AgFuyySabg+rJq6CgINWpU8duX0BAgEqVKmXuHzBggJ588kmFhobKarVq6NChiomJ0W233SZJateunWrXrq0HHnhAU6ZMUUpKip5//nkNHjz4igtVAACKP3KNzikAwE3lGIZyjII/nOOIOv7u1VdflYeHh7p3766MjAzFxsbqzTffNI97enpq2bJlGjRokGJiYhQQEKB+/fppwoQJDm0HAMC9kGuSxTAc3HoAAApRWlqagoODlfxjeYctuR9Z63elpqbaLRwBAEBRINf+wsgpAMAt/X3Rh4LWAwCAs5FrdE4BAG7KJkM5N3iIAwCKD3KN1XoBAAAAAC6AkVMAgFti+hMAoDgh1xg5BQAAAAC4AEZOAQBuyVWX3AcA4HqQa3ROAQBuyva/zRH1AADgbOQa03oBAAAAAC6AkVMAgFvKcdCS+46oAwCAgiLX6JwCANxUjnFxc0Q9AAA4G7nGtF4AAAAAgAtg5BQA4JZYOAIAUJyQa3ROAQBuyiaLcmRxSD0AADgbuca0XgAAAACAC2DkFADglmzGxc0R9QAA4GzkGiOnAAAAAAAXwMgpAMAt5Tjo2RxH1AEAQEGRa3ROAQBuihAHABQn5BrTegEAAAAALoCRUwCAW7IZFtkMByy574A6AAAoKHKNzikAwE0x/QkAUJyQa0zrBQAAAAC4AEZOAQBuKUceynHAPdYcB7QFAICCItfonAIA3JThoGdzDDd+NgcAUHyQa0zrBQAAAAC4AEZOAQBuiYUjAADFCblG5xQA4KZyDA/lGA54NsdwQGMAACggco1pvQAAAAAAF8DIKQDALdlkkc0B91htcuNbzACAYoNcY+QUAAAAAOACGDkFALglFo4AABQn5BqdUwCAm3LcwhHuO/0JAFB8kGtM6wUAAAAAuAA6pwAKxbx582SxWPTLL79cs2zlypXVv3//Qm8TipeLC0c4ZgMAwNnINTqnAPJp9+7d6tOnj8qXLy9fX19FREQoLi5Ou3fvdnbTcIOxyUM5DtgcsTIiANeXe9PUz89Pv//++yXHW7ZsqTp16jihZcBF5BqdUwD5sGTJEjVs2FBr1qzRgw8+qDfffFMDBgzQ2rVr1bBhQ33yySfObiIAAFeVkZGhF1980dnNAHAZLIgEIE8OHjyoBx54QFWrVtWGDRtUpkwZ89jjjz+uZs2a6YEHHtDOnTtVtWpVJ7YUNwoWjgBwPRo0aKC3335bzz77rCIiIpzdHMBErjFyCiCPpk6dqnPnzumtt96y65hKUunSpTV79mydPXtWU6ZMuWIdhmFo0qRJqlChgkqUKKFWrVoxHRjXzfa/qUuO2ADcOEaNGqWcnJw8jZ5+8MEHatSokfz9/RUaGqpevXopOTnZPP7666/L09NTp0+fNvdNmzZNFotFTz75pLkvJydHQUFBGjlypEOvBcULuUbnFEAeLV26VJUrV1azZs0ue7x58+aqXLmyli9ffsU6xowZo9GjR6t+/fqaOnWqqlatqnbt2uns2bOF1WwAAOxUqVJFffv21dtvv60jR45csdzkyZPVt29f1ahRQ6+88oqeeOIJrVmzRs2bNzc7o82aNZPNZtPGjRvN9yUmJsrDw0OJiYnmvu+//17p6elq3rx5oV0XUBzQOQVwTampqTpy5Ijq169/1XL16tXTb7/9pjNnzlxy7MSJE5oyZYo6duyoZcuWafDgwZo7d6769++vP/74o7CajmIsx7A4bMuPmTNnql69erJarbJarYqJidGXX35pHm/ZsqUsFovd9uijj9rVcfjwYXXs2FElSpRQ2bJlNWLECGVnZzvkcwFwbc8995yys7P10ksvXfb4r7/+qrFjx2rSpElauHChBg0apDFjxmjt2rX67bff9Oabb0qS6tevL6vVanZEDcPQxo0b1b17d7NDKv3VYb3jjjuK5gLhlsg1OqcA8iC3sxkUFHTVcrnH09LSLjm2evVqZWZmaujQobJY/vql+cQTTziuoUARqFChgl588UVt3bpV3333ne6880516dLFbor6wIEDdfToUXP7+3T3nJwcdezYUZmZmfrmm2/03nvvad68eRozZowzLge4IVWtWlUPPPCA3nrrLR09evSS40uWLJHNZlOPHj30xx9/mFt4eLhq1KihtWvXSpI8PDx0++23a8OGDZKkvXv36s8//9QzzzwjwzCUlJQk6WLntE6dOgoJCSmyawTyypVyjc4pgGvK7XRebkT0767Wif31118lSTVq1LDbX6ZMGZUsWdIRzcQNxhHL7edu+dG5c2fdddddqlGjhmrWrKnJkycrMDBQ//3vf80yJUqUUHh4uLlZrVbz2FdffaU9e/bogw8+UIMGDdShQwdNnDhRb7zxhjIzMx32+QC4uueff17Z2dmXffZ0//79MgxDNWrUUJkyZey2vXv36vjx42bZZs2aaevWrTp//rwSExNVrlw5NWzYUPXr1zdHVDdu3HjFx2KAXOQanVMAeRAcHKxy5cpp586dVy23c+dOlS9f3u4XFlBYbIaHwzbp4oj/37eMjIxrtiEnJ0cLFy7U2bNnFRMTY+5PSEhQ6dKlVadOHT377LM6d+6ceSwpKUl169ZVWFiYuS82NlZpaWksEAYUoapVq6pPnz6XHT212WyyWCxasWKFVq1adck2e/Zss2zTpk2VlZWlpKQkJSYmmp3QZs2aKTExUT/++KNOnDhB5xTXRK7
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x400 with 4 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"# Предполагается, что optimized_metrics - DataFrame с матрицами ошибок\n",
|
|||
|
"for index in range(0, len(optimized_metrics)):\n",
|
|||
|
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"Below Average\", \"Above Average\"] # Измените метки на нужные\n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(optimized_metrics.index[index]) # Заголовок с названием модели\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"В желтом квадрате мы видим значение 403, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"Below Average\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
|
|||
|
"\n",
|
|||
|
"В зеленом квадрате значение 117 указывает на количество правильно классифицированных объектов, отнесенных к классу \"Above Average\". Это также является показателем высокой точности модели в определении объектов данного класса."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Определение достижимого уровня качества модели для второй задачи "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Подготовка данных__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Загрузка данных и создание целевой переменной"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 40,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Среднее значение поля 'Networth': 4.8607499999999995\n",
|
|||
|
" Rank Name Networth Age Country \\\n",
|
|||
|
"0 1 Elon Musk 219.0 50 United States \n",
|
|||
|
"1 2 Jeff Bezos 171.0 58 United States \n",
|
|||
|
"2 3 Bernard Arnault & family 158.0 73 France \n",
|
|||
|
"3 4 Bill Gates 129.0 66 United States \n",
|
|||
|
"4 5 Warren Buffett 118.0 91 United States \n",
|
|||
|
"\n",
|
|||
|
" Source Industry above_average_networth \n",
|
|||
|
"0 Tesla, SpaceX Automotive 1 \n",
|
|||
|
"1 Amazon Technology 1 \n",
|
|||
|
"2 LVMH Fashion & Retail 1 \n",
|
|||
|
"3 Microsoft Technology 1 \n",
|
|||
|
"4 Berkshire Hathaway Finance & Investments 1 \n",
|
|||
|
"Статистическое описание DataFrame:\n",
|
|||
|
" Rank Networth Age above_average_networth\n",
|
|||
|
"count 2600.000000 2600.000000 2600.000000 2600.000000\n",
|
|||
|
"mean 1269.570769 4.860750 64.271923 0.225000\n",
|
|||
|
"std 728.146364 10.659671 13.220607 0.417663\n",
|
|||
|
"min 1.000000 1.000000 19.000000 0.000000\n",
|
|||
|
"25% 637.000000 1.500000 55.000000 0.000000\n",
|
|||
|
"50% 1292.000000 2.400000 64.000000 0.000000\n",
|
|||
|
"75% 1929.000000 4.500000 74.000000 0.000000\n",
|
|||
|
"max 2578.000000 219.000000 100.000000 1.000000\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//Forbes Billionaires.csv\")\n",
|
|||
|
"\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"\n",
|
|||
|
"# Вычисление среднего значения поля \"Networth\"\n",
|
|||
|
"average_networth = df['Networth'].mean()\n",
|
|||
|
"print(f\"Среднее значение поля 'Networth': {average_networth}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создание новой колонки, указывающей, выше или ниже среднего значение чистого состояния\n",
|
|||
|
"df['above_average_networth'] = (df['Networth'] > average_networth).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод DataFrame с новой колонкой\n",
|
|||
|
"print(df.head())\n",
|
|||
|
"\n",
|
|||
|
"# Примерный анализ данных\n",
|
|||
|
"print(\"Статистическое описание DataFrame:\")\n",
|
|||
|
"print(df.describe())\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии\n",
|
|||
|
"\n",
|
|||
|
"Целевой признак -- above_average_close"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 42,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Rank</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Networth</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Country</th>\n",
|
|||
|
" <th>Source</th>\n",
|
|||
|
" <th>Industry</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>582</th>\n",
|
|||
|
" <td>579</td>\n",
|
|||
|
" <td>Alexandra Schoerghuber & family</td>\n",
|
|||
|
" <td>4.9</td>\n",
|
|||
|
" <td>63</td>\n",
|
|||
|
" <td>Germany</td>\n",
|
|||
|
" <td>real estate</td>\n",
|
|||
|
" <td>Real Estate</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>48</th>\n",
|
|||
|
" <td>49</td>\n",
|
|||
|
" <td>He Xiangjian</td>\n",
|
|||
|
" <td>28.3</td>\n",
|
|||
|
" <td>79</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>home appliances</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1772</th>\n",
|
|||
|
" <td>1729</td>\n",
|
|||
|
" <td>Bruce Mathieson</td>\n",
|
|||
|
" <td>1.7</td>\n",
|
|||
|
" <td>78</td>\n",
|
|||
|
" <td>Australia</td>\n",
|
|||
|
" <td>hotels</td>\n",
|
|||
|
" <td>Food & Beverage</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>964</th>\n",
|
|||
|
" <td>951</td>\n",
|
|||
|
" <td>Pansy Ho</td>\n",
|
|||
|
" <td>3.2</td>\n",
|
|||
|
" <td>59</td>\n",
|
|||
|
" <td>Hong Kong</td>\n",
|
|||
|
" <td>casinos</td>\n",
|
|||
|
" <td>Gambling & Casinos</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2213</th>\n",
|
|||
|
" <td>2190</td>\n",
|
|||
|
" <td>Sasson Dayan & family</td>\n",
|
|||
|
" <td>1.3</td>\n",
|
|||
|
" <td>82</td>\n",
|
|||
|
" <td>Brazil</td>\n",
|
|||
|
" <td>banking</td>\n",
|
|||
|
" <td>Finance & Investments</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1638</th>\n",
|
|||
|
" <td>1579</td>\n",
|
|||
|
" <td>Wang Chou-hsiong</td>\n",
|
|||
|
" <td>1.9</td>\n",
|
|||
|
" <td>81</td>\n",
|
|||
|
" <td>Taiwan</td>\n",
|
|||
|
" <td>footwear</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1095</th>\n",
|
|||
|
" <td>1096</td>\n",
|
|||
|
" <td>Jose Joao Abdalla Filho</td>\n",
|
|||
|
" <td>2.8</td>\n",
|
|||
|
" <td>76</td>\n",
|
|||
|
" <td>Brazil</td>\n",
|
|||
|
" <td>investments</td>\n",
|
|||
|
" <td>Finance & Investments</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1130</th>\n",
|
|||
|
" <td>1096</td>\n",
|
|||
|
" <td>Lin Chen-hai</td>\n",
|
|||
|
" <td>2.8</td>\n",
|
|||
|
" <td>75</td>\n",
|
|||
|
" <td>Taiwan</td>\n",
|
|||
|
" <td>real estate</td>\n",
|
|||
|
" <td>Real Estate</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1294</th>\n",
|
|||
|
" <td>1292</td>\n",
|
|||
|
" <td>Banwari Lal Bawri</td>\n",
|
|||
|
" <td>2.4</td>\n",
|
|||
|
" <td>69</td>\n",
|
|||
|
" <td>India</td>\n",
|
|||
|
" <td>pharmaceuticals</td>\n",
|
|||
|
" <td>Healthcare</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>860</th>\n",
|
|||
|
" <td>851</td>\n",
|
|||
|
" <td>Kuok Khoon Hong</td>\n",
|
|||
|
" <td>3.5</td>\n",
|
|||
|
" <td>72</td>\n",
|
|||
|
" <td>Singapore</td>\n",
|
|||
|
" <td>palm oil</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>2080 rows × 7 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Rank Name Networth Age Country \\\n",
|
|||
|
"582 579 Alexandra Schoerghuber & family 4.9 63 Germany \n",
|
|||
|
"48 49 He Xiangjian 28.3 79 China \n",
|
|||
|
"1772 1729 Bruce Mathieson 1.7 78 Australia \n",
|
|||
|
"964 951 Pansy Ho 3.2 59 Hong Kong \n",
|
|||
|
"2213 2190 Sasson Dayan & family 1.3 82 Brazil \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"1638 1579 Wang Chou-hsiong 1.9 81 Taiwan \n",
|
|||
|
"1095 1096 Jose Joao Abdalla Filho 2.8 76 Brazil \n",
|
|||
|
"1130 1096 Lin Chen-hai 2.8 75 Taiwan \n",
|
|||
|
"1294 1292 Banwari Lal Bawri 2.4 69 India \n",
|
|||
|
"860 851 Kuok Khoon Hong 3.5 72 Singapore \n",
|
|||
|
"\n",
|
|||
|
" Source Industry \n",
|
|||
|
"582 real estate Real Estate \n",
|
|||
|
"48 home appliances Manufacturing \n",
|
|||
|
"1772 hotels Food & Beverage \n",
|
|||
|
"964 casinos Gambling & Casinos \n",
|
|||
|
"2213 banking Finance & Investments \n",
|
|||
|
"... ... ... \n",
|
|||
|
"1638 footwear Manufacturing \n",
|
|||
|
"1095 investments Finance & Investments \n",
|
|||
|
"1130 real estate Real Estate \n",
|
|||
|
"1294 pharmaceuticals Healthcare \n",
|
|||
|
"860 palm oil Manufacturing \n",
|
|||
|
"\n",
|
|||
|
"[2080 rows x 7 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>above_average_networth</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>582</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>48</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1772</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>964</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2213</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1638</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1095</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1130</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1294</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>860</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>2080 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" above_average_networth\n",
|
|||
|
"582 1\n",
|
|||
|
"48 1\n",
|
|||
|
"1772 0\n",
|
|||
|
"964 0\n",
|
|||
|
"2213 0\n",
|
|||
|
"... ...\n",
|
|||
|
"1638 0\n",
|
|||
|
"1095 0\n",
|
|||
|
"1130 0\n",
|
|||
|
"1294 0\n",
|
|||
|
"860 0\n",
|
|||
|
"\n",
|
|||
|
"[2080 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Rank</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Networth</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Country</th>\n",
|
|||
|
" <th>Source</th>\n",
|
|||
|
" <th>Industry</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1593</th>\n",
|
|||
|
" <td>1579</td>\n",
|
|||
|
" <td>Guangming Fu & family</td>\n",
|
|||
|
" <td>1.9</td>\n",
|
|||
|
" <td>68</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>poultry</td>\n",
|
|||
|
" <td>Food & Beverage</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>196</th>\n",
|
|||
|
" <td>197</td>\n",
|
|||
|
" <td>Leon Black</td>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>70</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>private equity</td>\n",
|
|||
|
" <td>Finance & Investments</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>239</th>\n",
|
|||
|
" <td>235</td>\n",
|
|||
|
" <td>Zong Qinghou</td>\n",
|
|||
|
" <td>8.8</td>\n",
|
|||
|
" <td>76</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>beverages</td>\n",
|
|||
|
" <td>Food & Beverage</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2126</th>\n",
|
|||
|
" <td>2076</td>\n",
|
|||
|
" <td>Kurt Krieger</td>\n",
|
|||
|
" <td>1.4</td>\n",
|
|||
|
" <td>74</td>\n",
|
|||
|
" <td>Germany</td>\n",
|
|||
|
" <td>furniture retailing</td>\n",
|
|||
|
" <td>Fashion & Retail</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1587</th>\n",
|
|||
|
" <td>1579</td>\n",
|
|||
|
" <td>Chen Kaichen</td>\n",
|
|||
|
" <td>1.9</td>\n",
|
|||
|
" <td>64</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>household chemicals</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1778</th>\n",
|
|||
|
" <td>1729</td>\n",
|
|||
|
" <td>Jorge Perez</td>\n",
|
|||
|
" <td>1.7</td>\n",
|
|||
|
" <td>72</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>real estate</td>\n",
|
|||
|
" <td>Real Estate</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>166</th>\n",
|
|||
|
" <td>167</td>\n",
|
|||
|
" <td>Brian Chesky</td>\n",
|
|||
|
" <td>11.5</td>\n",
|
|||
|
" <td>40</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>Airbnb</td>\n",
|
|||
|
" <td>Technology</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>949</th>\n",
|
|||
|
" <td>913</td>\n",
|
|||
|
" <td>Zhong Ruonong & family</td>\n",
|
|||
|
" <td>3.3</td>\n",
|
|||
|
" <td>59</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>electronics</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>49</th>\n",
|
|||
|
" <td>50</td>\n",
|
|||
|
" <td>Miriam Adelson</td>\n",
|
|||
|
" <td>27.5</td>\n",
|
|||
|
" <td>76</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>casinos</td>\n",
|
|||
|
" <td>Gambling & Casinos</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2511</th>\n",
|
|||
|
" <td>2448</td>\n",
|
|||
|
" <td>Lou Boliang</td>\n",
|
|||
|
" <td>1.1</td>\n",
|
|||
|
" <td>58</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>pharmaceuticals</td>\n",
|
|||
|
" <td>Healthcare</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>520 rows × 7 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Rank Name Networth Age Country \\\n",
|
|||
|
"1593 1579 Guangming Fu & family 1.9 68 China \n",
|
|||
|
"196 197 Leon Black 10.0 70 United States \n",
|
|||
|
"239 235 Zong Qinghou 8.8 76 China \n",
|
|||
|
"2126 2076 Kurt Krieger 1.4 74 Germany \n",
|
|||
|
"1587 1579 Chen Kaichen 1.9 64 China \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"1778 1729 Jorge Perez 1.7 72 United States \n",
|
|||
|
"166 167 Brian Chesky 11.5 40 United States \n",
|
|||
|
"949 913 Zhong Ruonong & family 3.3 59 China \n",
|
|||
|
"49 50 Miriam Adelson 27.5 76 United States \n",
|
|||
|
"2511 2448 Lou Boliang 1.1 58 United States \n",
|
|||
|
"\n",
|
|||
|
" Source Industry \n",
|
|||
|
"1593 poultry Food & Beverage \n",
|
|||
|
"196 private equity Finance & Investments \n",
|
|||
|
"239 beverages Food & Beverage \n",
|
|||
|
"2126 furniture retailing Fashion & Retail \n",
|
|||
|
"1587 household chemicals Manufacturing \n",
|
|||
|
"... ... ... \n",
|
|||
|
"1778 real estate Real Estate \n",
|
|||
|
"166 Airbnb Technology \n",
|
|||
|
"949 electronics Manufacturing \n",
|
|||
|
"49 casinos Gambling & Casinos \n",
|
|||
|
"2511 pharmaceuticals Healthcare \n",
|
|||
|
"\n",
|
|||
|
"[520 rows x 7 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>above_average_networth</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1593</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>196</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>239</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2126</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1587</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1778</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>166</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>949</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>49</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2511</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>520 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" above_average_networth\n",
|
|||
|
"1593 0\n",
|
|||
|
"196 1\n",
|
|||
|
"239 1\n",
|
|||
|
"2126 0\n",
|
|||
|
"1587 0\n",
|
|||
|
"... ...\n",
|
|||
|
"1778 0\n",
|
|||
|
"166 1\n",
|
|||
|
"949 0\n",
|
|||
|
"49 1\n",
|
|||
|
"2511 0\n",
|
|||
|
"\n",
|
|||
|
"[520 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from typing import Tuple\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from pandas import DataFrame\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"def split_into_train_test(\n",
|
|||
|
" df_input: DataFrame,\n",
|
|||
|
" target_colname: str = \"above_average_networth\", \n",
|
|||
|
" frac_train: float = 0.8,\n",
|
|||
|
" random_state: int = None,\n",
|
|||
|
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
|||
|
" \n",
|
|||
|
" if not (0 < frac_train < 1):\n",
|
|||
|
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
|
|||
|
" \n",
|
|||
|
" # Проверка наличия целевого признака\n",
|
|||
|
" if target_colname not in df_input.columns:\n",
|
|||
|
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
|
|||
|
" \n",
|
|||
|
" # Разделяем данные на признаки и целевую переменную\n",
|
|||
|
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
|
|||
|
" y = df_input[[target_colname]] # Целевая переменная\n",
|
|||
|
"\n",
|
|||
|
" # Разделяем данные на обучающую и тестовую выборки\n",
|
|||
|
" X_train, X_test, y_train, y_test = train_test_split(\n",
|
|||
|
" X, y,\n",
|
|||
|
" test_size=(1.0 - frac_train),\n",
|
|||
|
" random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" return X_train, X_test, y_train, y_test\n",
|
|||
|
"\n",
|
|||
|
"# Применение функции для разделения данных\n",
|
|||
|
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
|
|||
|
" df, \n",
|
|||
|
" target_colname=\"above_average_networth\", \n",
|
|||
|
" frac_train=0.8, \n",
|
|||
|
" random_state=42 \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Для отображения результатов\n",
|
|||
|
"display(\"X_train\", X_train)\n",
|
|||
|
"display(\"y_train\", y_train)\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_test\", X_test)\n",
|
|||
|
"display(\"y_test\", y_test)\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование конвейера для классификации данных\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
|
|||
|
"\n",
|
|||
|
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
|
|||
|
"\n",
|
|||
|
"features_preprocessing -- трансформер для предобработки признаков\n",
|
|||
|
"\n",
|
|||
|
"features_engineering -- трансформер для конструирования признаков\n",
|
|||
|
"\n",
|
|||
|
"drop_columns -- трансформер для удаления колонок\n",
|
|||
|
"\n",
|
|||
|
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 44,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Networth Age Country_Argentina Country_Australia \\\n",
|
|||
|
"0 20.092595 -1.079729 0.0 0.0 \n",
|
|||
|
"1 15.588775 -0.474496 0.0 0.0 \n",
|
|||
|
"2 14.368991 0.660314 0.0 0.0 \n",
|
|||
|
"3 11.647933 0.130736 0.0 0.0 \n",
|
|||
|
"4 10.615808 2.022087 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"2595 -0.362253 1.189893 0.0 0.0 \n",
|
|||
|
"2596 -0.362253 1.341201 0.0 0.0 \n",
|
|||
|
"2597 -0.362253 0.509006 0.0 0.0 \n",
|
|||
|
"2598 -0.362253 0.282044 0.0 0.0 \n",
|
|||
|
"2599 -0.362253 0.357698 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Austria Country_Barbados Country_Belgium Country_Belize \\\n",
|
|||
|
"0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"2595 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2596 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2597 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2598 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2599 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Brazil Country_Bulgaria ... Industry_Manufacturing \\\n",
|
|||
|
"0 0.0 0.0 ... 0.0 \n",
|
|||
|
"1 0.0 0.0 ... 0.0 \n",
|
|||
|
"2 0.0 0.0 ... 0.0 \n",
|
|||
|
"3 0.0 0.0 ... 0.0 \n",
|
|||
|
"4 0.0 0.0 ... 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"2595 0.0 0.0 ... 0.0 \n",
|
|||
|
"2596 0.0 0.0 ... 0.0 \n",
|
|||
|
"2597 0.0 0.0 ... 0.0 \n",
|
|||
|
"2598 0.0 0.0 ... 0.0 \n",
|
|||
|
"2599 0.0 0.0 ... 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Media & Entertainment Industry_Metals & Mining \\\n",
|
|||
|
"0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"2595 0.0 0.0 \n",
|
|||
|
"2596 0.0 0.0 \n",
|
|||
|
"2597 0.0 0.0 \n",
|
|||
|
"2598 0.0 0.0 \n",
|
|||
|
"2599 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Real Estate Industry_Service Industry_Sports \\\n",
|
|||
|
"0 0.0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"2595 0.0 0.0 0.0 \n",
|
|||
|
"2596 0.0 0.0 0.0 \n",
|
|||
|
"2597 0.0 0.0 0.0 \n",
|
|||
|
"2598 0.0 0.0 0.0 \n",
|
|||
|
"2599 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Technology Industry_Telecom Industry_diversified \\\n",
|
|||
|
"0 0.0 0.0 0.0 \n",
|
|||
|
"1 1.0 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 \n",
|
|||
|
"3 1.0 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"2595 0.0 0.0 0.0 \n",
|
|||
|
"2596 0.0 0.0 0.0 \n",
|
|||
|
"2597 0.0 0.0 0.0 \n",
|
|||
|
"2598 0.0 0.0 0.0 \n",
|
|||
|
"2599 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Networth_per_Age \n",
|
|||
|
"0 -18.608929 \n",
|
|||
|
"1 -32.853309 \n",
|
|||
|
"2 21.760834 \n",
|
|||
|
"3 89.095063 \n",
|
|||
|
"4 5.249926 \n",
|
|||
|
"... ... \n",
|
|||
|
"2595 -0.304441 \n",
|
|||
|
"2596 -0.270096 \n",
|
|||
|
"2597 -0.711686 \n",
|
|||
|
"2598 -1.284383 \n",
|
|||
|
"2599 -1.012732 \n",
|
|||
|
"\n",
|
|||
|
"[2600 rows x 988 columns]\n",
|
|||
|
"(2600, 988)\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn.base import BaseEstimator, TransformerMixin\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor \n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.pipeline import make_pipeline\n",
|
|||
|
"\n",
|
|||
|
"class ForbesBillionairesFeatures(BaseEstimator, TransformerMixin): \n",
|
|||
|
" def __init__(self):\n",
|
|||
|
" pass\n",
|
|||
|
"\n",
|
|||
|
" def fit(self, X, y=None):\n",
|
|||
|
" return self\n",
|
|||
|
"\n",
|
|||
|
" def transform(self, X, y=None):\n",
|
|||
|
" X[\"Networth_per_Age\"] = X[\"Networth\"] / X[\"Age\"]\n",
|
|||
|
" return X\n",
|
|||
|
"\n",
|
|||
|
" def get_feature_names_out(self, features_in):\n",
|
|||
|
" return np.append(features_in, [\"Networth_per_Age\"], axis=0) \n",
|
|||
|
"\n",
|
|||
|
"# Определите признаки для вашей задачи\n",
|
|||
|
"columns_to_drop = [\"Rank \", \"Name\"] \n",
|
|||
|
"num_columns = [\"Networth\", \"Age\"] \n",
|
|||
|
"cat_columns = [\"Country\", \"Source\", \"Industry\"]\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование числовых признаков\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование категориальных признаков\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Формирование конвейера\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\" \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Окончательный конвейер\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" (\"custom_features\", ForbesBillionairesFeatures()), \n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//Forbes Billionaires.csv\")\n",
|
|||
|
"\n",
|
|||
|
"# Создаем целевой признак\n",
|
|||
|
"average_networth = df['Networth'].mean()\n",
|
|||
|
"df['above_average_networth'] = (df['Networth'] > average_networth).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Подготовка данных\n",
|
|||
|
"X = df.drop('above_average_networth', axis=1)\n",
|
|||
|
"y = df['above_average_networth'].values.ravel() \n",
|
|||
|
"\n",
|
|||
|
"# Применение конвейера\n",
|
|||
|
"X_processed = pipeline_end.fit_transform(X)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод\n",
|
|||
|
"print(X_processed)\n",
|
|||
|
"print(X_processed.shape)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Демонстрация работы конвейера__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 45,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Networth</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Country_Argentina</th>\n",
|
|||
|
" <th>Country_Australia</th>\n",
|
|||
|
" <th>Country_Austria</th>\n",
|
|||
|
" <th>Country_Barbados</th>\n",
|
|||
|
" <th>Country_Belgium</th>\n",
|
|||
|
" <th>Country_Belize</th>\n",
|
|||
|
" <th>Country_Brazil</th>\n",
|
|||
|
" <th>Country_Bulgaria</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>Industry_Manufacturing</th>\n",
|
|||
|
" <th>Industry_Media & Entertainment</th>\n",
|
|||
|
" <th>Industry_Metals & Mining</th>\n",
|
|||
|
" <th>Industry_Real Estate</th>\n",
|
|||
|
" <th>Industry_Service</th>\n",
|
|||
|
" <th>Industry_Sports</th>\n",
|
|||
|
" <th>Industry_Technology</th>\n",
|
|||
|
" <th>Industry_Telecom</th>\n",
|
|||
|
" <th>Industry_diversified</th>\n",
|
|||
|
" <th>Networth_per_Age</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>582</th>\n",
|
|||
|
" <td>-0.013606</td>\n",
|
|||
|
" <td>-0.109934</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.123766</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>48</th>\n",
|
|||
|
" <td>1.994083</td>\n",
|
|||
|
" <td>1.079079</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.847949</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1772</th>\n",
|
|||
|
" <td>-0.288162</td>\n",
|
|||
|
" <td>1.004766</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-0.286795</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>964</th>\n",
|
|||
|
" <td>-0.159464</td>\n",
|
|||
|
" <td>-0.407187</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.391623</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2213</th>\n",
|
|||
|
" <td>-0.322481</td>\n",
|
|||
|
" <td>1.302019</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-0.247678</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1638</th>\n",
|
|||
|
" <td>-0.271002</td>\n",
|
|||
|
" <td>1.227706</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-0.220739</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1095</th>\n",
|
|||
|
" <td>-0.193783</td>\n",
|
|||
|
" <td>0.856139</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-0.226346</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1130</th>\n",
|
|||
|
" <td>-0.193783</td>\n",
|
|||
|
" <td>0.781826</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-0.247860</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1294</th>\n",
|
|||
|
" <td>-0.228103</td>\n",
|
|||
|
" <td>0.335946</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-0.678986</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>860</th>\n",
|
|||
|
" <td>-0.133724</td>\n",
|
|||
|
" <td>0.558886</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-0.239269</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>2080 rows × 857 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Networth Age Country_Argentina Country_Australia \\\n",
|
|||
|
"582 -0.013606 -0.109934 0.0 0.0 \n",
|
|||
|
"48 1.994083 1.079079 0.0 0.0 \n",
|
|||
|
"1772 -0.288162 1.004766 0.0 1.0 \n",
|
|||
|
"964 -0.159464 -0.407187 0.0 0.0 \n",
|
|||
|
"2213 -0.322481 1.302019 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"1638 -0.271002 1.227706 0.0 0.0 \n",
|
|||
|
"1095 -0.193783 0.856139 0.0 0.0 \n",
|
|||
|
"1130 -0.193783 0.781826 0.0 0.0 \n",
|
|||
|
"1294 -0.228103 0.335946 0.0 0.0 \n",
|
|||
|
"860 -0.133724 0.558886 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Austria Country_Barbados Country_Belgium Country_Belize \\\n",
|
|||
|
"582 0.0 0.0 0.0 0.0 \n",
|
|||
|
"48 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1772 0.0 0.0 0.0 0.0 \n",
|
|||
|
"964 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2213 0.0 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"1638 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1095 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1130 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1294 0.0 0.0 0.0 0.0 \n",
|
|||
|
"860 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Brazil Country_Bulgaria ... Industry_Manufacturing \\\n",
|
|||
|
"582 0.0 0.0 ... 0.0 \n",
|
|||
|
"48 0.0 0.0 ... 1.0 \n",
|
|||
|
"1772 0.0 0.0 ... 0.0 \n",
|
|||
|
"964 0.0 0.0 ... 0.0 \n",
|
|||
|
"2213 1.0 0.0 ... 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"1638 0.0 0.0 ... 1.0 \n",
|
|||
|
"1095 1.0 0.0 ... 0.0 \n",
|
|||
|
"1130 0.0 0.0 ... 0.0 \n",
|
|||
|
"1294 0.0 0.0 ... 0.0 \n",
|
|||
|
"860 0.0 0.0 ... 1.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Media & Entertainment Industry_Metals & Mining \\\n",
|
|||
|
"582 0.0 0.0 \n",
|
|||
|
"48 0.0 0.0 \n",
|
|||
|
"1772 0.0 0.0 \n",
|
|||
|
"964 0.0 0.0 \n",
|
|||
|
"2213 0.0 0.0 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"1638 0.0 0.0 \n",
|
|||
|
"1095 0.0 0.0 \n",
|
|||
|
"1130 0.0 0.0 \n",
|
|||
|
"1294 0.0 0.0 \n",
|
|||
|
"860 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Real Estate Industry_Service Industry_Sports \\\n",
|
|||
|
"582 1.0 0.0 0.0 \n",
|
|||
|
"48 0.0 0.0 0.0 \n",
|
|||
|
"1772 0.0 0.0 0.0 \n",
|
|||
|
"964 0.0 0.0 0.0 \n",
|
|||
|
"2213 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"1638 0.0 0.0 0.0 \n",
|
|||
|
"1095 0.0 0.0 0.0 \n",
|
|||
|
"1130 1.0 0.0 0.0 \n",
|
|||
|
"1294 0.0 0.0 0.0 \n",
|
|||
|
"860 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Technology Industry_Telecom Industry_diversified \\\n",
|
|||
|
"582 0.0 0.0 0.0 \n",
|
|||
|
"48 0.0 0.0 0.0 \n",
|
|||
|
"1772 0.0 0.0 0.0 \n",
|
|||
|
"964 0.0 0.0 0.0 \n",
|
|||
|
"2213 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"1638 0.0 0.0 0.0 \n",
|
|||
|
"1095 0.0 0.0 0.0 \n",
|
|||
|
"1130 0.0 0.0 0.0 \n",
|
|||
|
"1294 0.0 0.0 0.0 \n",
|
|||
|
"860 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Networth_per_Age \n",
|
|||
|
"582 0.123766 \n",
|
|||
|
"48 1.847949 \n",
|
|||
|
"1772 -0.286795 \n",
|
|||
|
"964 0.391623 \n",
|
|||
|
"2213 -0.247678 \n",
|
|||
|
"... ... \n",
|
|||
|
"1638 -0.220739 \n",
|
|||
|
"1095 -0.226346 \n",
|
|||
|
"1130 -0.247860 \n",
|
|||
|
"1294 -0.678986 \n",
|
|||
|
"860 -0.239269 \n",
|
|||
|
"\n",
|
|||
|
"[2080 rows x 857 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 45,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"preprocessed_df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Формирование набора моделей для классификации\n",
|
|||
|
"\n",
|
|||
|
"logistic -- логистическая регрессия\n",
|
|||
|
"\n",
|
|||
|
"ridge -- гребневая регрессия\n",
|
|||
|
"\n",
|
|||
|
"decision_tree -- дерево решений\n",
|
|||
|
"\n",
|
|||
|
"knn -- k-ближайших соседей\n",
|
|||
|
"\n",
|
|||
|
"naive_bayes -- наивный Байесовский классификатор\n",
|
|||
|
"\n",
|
|||
|
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"random_forest -- метод случайного леса (набор деревьев решений)\n",
|
|||
|
"\n",
|
|||
|
"mlp -- многослойный персептрон (нейронная сеть)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 46,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
|
|||
|
"\n",
|
|||
|
"class_models = {\n",
|
|||
|
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
|
|||
|
" \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
|
|||
|
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
|
|||
|
" \"decision_tree\": {\n",
|
|||
|
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
|
|||
|
" },\n",
|
|||
|
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
|
|||
|
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
|
|||
|
" \"gradient_boosting\": {\n",
|
|||
|
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
|
|||
|
" },\n",
|
|||
|
" \"random_forest\": {\n",
|
|||
|
" \"model\": ensemble.RandomForestClassifier(\n",
|
|||
|
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
" \"mlp\": {\n",
|
|||
|
" \"model\": neural_network.MLPClassifier(\n",
|
|||
|
" hidden_layer_sizes=(7,),\n",
|
|||
|
" max_iter=500,\n",
|
|||
|
" early_stopping=True,\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" )\n",
|
|||
|
" },\n",
|
|||
|
"}"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 48,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: ridge\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: decision_tree\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: knn\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: naive_bayes\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: gradient_boosting\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: random_forest\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: mlp\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
"\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
|||
|
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
|||
|
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
|
|||
|
" y_train, y_train_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
|
|||
|
" y_test, y_test_probs\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )\n",
|
|||
|
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
|
|||
|
" y_test, y_test_predict\n",
|
|||
|
" )"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Сводная таблица оценок качества для использованных моделей классификации\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Матрица неточностей__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 50,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4UAAAQ9CAYAAADu7ug2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVwU5R8H8M9yIzeIHIKIggh5kGiGqXijeWt5YeKRlZlXaV55oCKlmZpXpSia+MvUMu8z8SDvKw9EwVvxRFFUrt3n9wcxuSEr6MLuDp/36zWv2plnZ55ZlQ/f2eeZUQghBIiIiIiIiKhUMtJ1B4iIiIiIiEh3WBQSERERERGVYiwKiYiIiIiISjEWhURERERERKUYi0IiIiIiIqJSjEUhERERERFRKcaikIiIiIiIqBRjUUhERERERFSKsSgkIiIiIiIqxVgUEmlZTEwMFAoFLl++XCz7v3z5MhQKBWJiYrSyv7i4OCgUCsTFxWllf0RERHIxceJEKBSKQrVVKBSYOHFi8XaIqJiwKCQqJebPn6+1QpKIiIiI5MNE1x0goqLx8vLCs2fPYGpqWqT3zZ8/H2XLlkXv3r3V1jds2BDPnj2DmZmZFntJRERk+L766iuMGjVK190gKnYsCokMjEKhgIWFhdb2Z2RkpNX9ERERycGTJ09gZWUFExP+ukzyx+GjRCVg/vz5eOONN2Bubg53d3cMHDgQDx8+zNdu3rx5qFSpEiwtLfHWW29h7969aNSoERo1aiS1edGcwlu3bqFPnz7w8PCAubk53Nzc0L59e2leY8WKFXHmzBns3r0bCoUCCoVC2mdBcwoPHjyId999Fw4ODrCyskKNGjUwe/Zs7X4wREREeiBv7uDZs2fRo0cPODg4oH79+i+cU5iZmYlhw4bB2dkZNjY2aNeuHa5fv/7C/cbFxaF27dqwsLBA5cqV8eOPPxY4T3H58uUICgqCpaUlHB0d0a1bN1y7dq1Yzpfov3jpg6iYTZw4EREREWjWrBkGDBiAxMRELFiwAIcPH0Z8fLw0DHTBggX47LPP0KBBAwwbNgyXL19Ghw4d4ODgAA8PD43H6Ny5M86cOYNBgwahYsWKuHPnDrZv346rV6+iYsWKmDVrFgYNGgRra2uMHTsWAODi4lLg/rZv3442bdrAzc0NQ4YMgaurKxISErBhwwYMGTJEex8OERGRHnn//ffh6+uLqVOnQgiBO3fu5Gvz4YcfYvny5ejRowfq1auHP//8E61bt87X7vjx42jZsiXc3NwQEREBpVKJSZMmwdnZOV/byMhIjBs3Dl26dMGHH36Iu3fvYs6cOWjYsCGOHz8Oe3v74jhdon8JItKqJUuWCADi0qVL4s6dO8LMzEy0aNFCKJVKqc3cuXMFALF48WIhhBCZmZnCyclJ1KlTR2RnZ0vtYmJiBAAREhIirbt06ZIAIJYsWSKEEOLBgwcCgJg+fbrGfr3xxhtq+8mza9cuAUDs2rVLCCFETk6O8Pb2Fl5eXuLBgwdqbVUqVeE/CCIiIgMxYcIEAUB07979hevznDhxQgAQn376qVq7Hj16CABiwoQJ0rq2bduKMmXKiBs3bkjrLly4IExMTNT2efnyZWFsbCwiIyPV9nnq1ClhYmKSbz1RceDwUaJitGPHDmRlZWHo0KEwMvr3n1v//v1ha2uLjRs3AgCOHDmC+/fvo3///mpzF8LCwuDg4KDxGJaWljAzM0NcXBwePHjw2n0+fvw4Ll26hKFDh+a7MlnY23ITEREZok8++UTj9k2bNgEABg8erLZ+6NChaq+VSiV27NiBDh06wN3dXVrv4+ODVq1aqbX97bffoFKp0KVLF9y7d09aXF1d4evri127dr3GGREVDoePEhWjK1euAAD8/PzU1puZmaFSpUrS9rz/+vj4qLUzMTFBxYoVNR7D3Nwc33zzDb744gu4uLjg7bffRps2bdCrVy+4uroWuc/JyckAgGrVqhX5vURERIbM29tb4/YrV67AyMgIlStXVlv/35y/c+cOnj17li/XgfxZf+HCBQgh4Ovr+8JjFvVu40SvgkUhkQwMHToUbdu2xdq1a7F161aMGzcOUVFR+PPPP/Hmm2/quntEREQGwdLSssSPqVKpoFAosHnzZhgbG+fbbm1tXeJ9otKHw0eJipGXlxcAIDExUW19VlYWLl26JG3P+29SUpJau5ycHOkOoi9TuXJlfPHFF9i2bRtOnz6NrKwszJgxQ9pe2KGfeVc/T58+Xaj2REREpYWXlxdUKpU0qibPf3O+XLlysLCwyJfrQP6sr1y5MoQQ8Pb2RrNmzfItb7/9tvZPhOg/WBQSFaNmzZrBzMwM33//PYQQ0vro6GikpaVJdyurXbs2nJycsHDhQuTk5EjtYmNjXzpP8OnTp8jIyFBbV7lyZdjY2CAzM1NaZ2Vl9cLHYPxXrVq14O3tjVmzZuVr//w5EBERlTZ58wG///57tfWzZs1Se21sbIxmzZph7dq1uHnzprQ+KSkJmzdvVmvbqVMnGBsbIyIiIl/OCiFw//59LZ4B0Ytx+ChRMXJ2dsbo0aMRERGBli1bol27dkhMTMT8+fNRp04d9OzZE0DuHMOJEydi0KBBaNKkCbp06YLLly8jJiYGlStX1vgt3/nz59G0aVN06dIFAQEBMDExwe+//47bt2+jW7duUrugoCAsWLAAU6ZMgY+PD8qVK4cmTZrk25+RkREWLFiAtm3bIjAwEH369IGbmxvOnTuHM2fOYOvWrdr/oIiIiAxAYGAgunfvjvnz5yMtLQ316tXDzp07X/iN4MSJE7Ft2za88847GDBgAJRKJebOnYtq1arhxIkTUrvKlStjypQpGD16tPQ4KhsbG1y6dAm///47PvroIwwfPrwEz5JKIxaFRMVs4sSJcHZ2xty5czFs2DA4Ojrio48+wtSpU9Umj3/22WcQQmDGjBkYPnw4atasiXXr1mHw4MGwsLAocP+enp7o3r07du7ciZ9//hkmJiaoWrUqfv31V3Tu3FlqN378eFy5cgXTpk3D48ePERIS8sKiEABCQ0Oxa9cuREREYMaMGVCpVKhcuTL69++vvQ+GiIjIAC1evBjOzs6IjY3F2rVr0aRJE2zcuBGenp5q7YKCgrB582YMHz4c48aNg6enJyZNmoSEhAScO3dOre2oUaNQpUoVzJw5ExEREQBy871FixZo165diZ0blV4KwfFgRHpLpVLB2dkZnTp1wsKFC3XdHSIiInpNHTp0wJkzZ3DhwgVdd4VIwjmFRHoiIyMj31yCZcuWITU1FY0aNdJNp4iIiOiVPXv2TO31hQsXsGnTJuY66R1+U0ikJ+Li4jBs2DC8//77cHJywrFjxxAdHQ1/f38cPXoUZmZmuu4iERERFYGbmxt69+4tPZt4wYIFyMzMxPHjxwt8LiGRLnBOIZGeqFixIjw9PfH9998jNTUVjo6O6NWrF77++msWhERERAaoZcuW+N///odbt27B3NwcwcHBmDp1KgtC0jv8ppCIiIiIiKgU45xCIiIiIiKiUoxFIRERERERUSnGOYX0WlQqFW7evAkbGxuND1gnkiMhBB4/fgx3d3cYGWn3GltGRgaysrJe2s7MzEzjcyyJqPRhNlNpxmx+NSwK6bXcvHkz38NaiUqba9euwcPDQ2v7y8jIgLeXNW7dUb60raurKy5dumRw4UNExYfZTMRsLioWhfRabGxsAABXjlWErTVHI+tCxyrVdd2FUisH2diHTdK/A23JysrCrTtKJB3xhK1Nwf+uHj1Wwaf2NWRlZRlU8BBR8WI2615Hvxq67kKplSOysQ8bmc1FxKKQXkvesBRbayON/0Co+JgoTHXdhdLrn3s3F9fwLGsbBaxtCt63ChwWRkT5MZt1j9msY4LZXFQsComI9FS2UCJbw1ODsoWqBHtDREREcs1mFoVERHpKBQEVCg4eTduIiIhI++SazSwKiYj0lAoCShkGDxERkaGSazazKCQi0lPZQoVsDdliqENUiIiIDJVcs5lFIRGRnlL9s2jaTkRERCVHrtnMopCISE8pXzJERdM2IiIi0j65ZjOLQiI
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x1000 with 16 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"for index, key in enumerate(class_models.keys()):\n",
|
|||
|
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"Below Average\", \"Above Average\"] \n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(key)\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
|||
|
"plt.show()\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Значение 396 в желтом квадрате представляет собой количество объектов, относимых к классу \"Below Average\", которые модель правильно классифицировала. Это свидетельствует о высоком уровне точности в идентификации этого класса.\n",
|
|||
|
"Значение 124 в голубом квадрате указывает на количество правильно классифицированных объектов класса \"Above Average\". Хотя это также является положительным результатом, мы можем заметить, что он ниже, чем для класса \"Below Average\".\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Точность, полнота, верность (аккуратность), F-мера__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 51,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_aeae3_row0_col0, #T_aeae3_row0_col1, #T_aeae3_row0_col2, #T_aeae3_row0_col3, #T_aeae3_row1_col0, #T_aeae3_row1_col1, #T_aeae3_row1_col2, #T_aeae3_row1_col3, #T_aeae3_row2_col0, #T_aeae3_row2_col1, #T_aeae3_row2_col2, #T_aeae3_row2_col3, #T_aeae3_row3_col2, #T_aeae3_row4_col0, #T_aeae3_row4_col1, #T_aeae3_row7_col2 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row0_col4, #T_aeae3_row0_col5, #T_aeae3_row0_col6, #T_aeae3_row0_col7, #T_aeae3_row1_col4, #T_aeae3_row1_col5, #T_aeae3_row1_col6, #T_aeae3_row1_col7, #T_aeae3_row2_col4, #T_aeae3_row2_col5, #T_aeae3_row2_col6, #T_aeae3_row2_col7 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row3_col0, #T_aeae3_row3_col3 {\n",
|
|||
|
" background-color: #a0da39;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row3_col1 {\n",
|
|||
|
" background-color: #98d83e;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row3_col4 {\n",
|
|||
|
" background-color: #d9586a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row3_col5, #T_aeae3_row3_col6 {\n",
|
|||
|
" background-color: #d8576b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row3_col7 {\n",
|
|||
|
" background-color: #d5546e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row4_col2 {\n",
|
|||
|
" background-color: #7fd34e;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row4_col3 {\n",
|
|||
|
" background-color: #5cc863;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row4_col4 {\n",
|
|||
|
" background-color: #d7566c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row4_col5, #T_aeae3_row4_col6 {\n",
|
|||
|
" background-color: #d45270;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row4_col7 {\n",
|
|||
|
" background-color: #cb4679;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row5_col0 {\n",
|
|||
|
" background-color: #a2da37;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row5_col1 {\n",
|
|||
|
" background-color: #95d840;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row5_col2 {\n",
|
|||
|
" background-color: #1e9b8a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row5_col3 {\n",
|
|||
|
" background-color: #20928c;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row5_col4 {\n",
|
|||
|
" background-color: #c8437b;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row5_col5 {\n",
|
|||
|
" background-color: #c5407e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row5_col6 {\n",
|
|||
|
" background-color: #b02991;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row5_col7 {\n",
|
|||
|
" background-color: #9410a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row6_col0 {\n",
|
|||
|
" background-color: #9bd93c;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row6_col1 {\n",
|
|||
|
" background-color: #84d44b;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row6_col2, #T_aeae3_row6_col3, #T_aeae3_row7_col0, #T_aeae3_row7_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row6_col4 {\n",
|
|||
|
" background-color: #c13b82;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row6_col5 {\n",
|
|||
|
" background-color: #c03a83;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row6_col6 {\n",
|
|||
|
" background-color: #9814a0;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row6_col7 {\n",
|
|||
|
" background-color: #7a02a8;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row7_col3 {\n",
|
|||
|
" background-color: #63cb5f;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_aeae3_row7_col4, #T_aeae3_row7_col5, #T_aeae3_row7_col6, #T_aeae3_row7_col7 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_aeae3\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_aeae3_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_aeae3_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_aeae3_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_aeae3_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_aeae3_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_aeae3_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_aeae3_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_aeae3_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_aeae3_level0_row0\" class=\"row_heading level0 row0\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_aeae3_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_aeae3_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_aeae3_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_aeae3_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_aeae3_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_aeae3_level0_row3\" class=\"row_heading level0 row3\" >ridge</th>\n",
|
|||
|
" <td id=\"T_aeae3_row3_col0\" class=\"data row3 col0\" >0.980851</td>\n",
|
|||
|
" <td id=\"T_aeae3_row3_col1\" class=\"data row3 col1\" >0.960630</td>\n",
|
|||
|
" <td id=\"T_aeae3_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row3_col3\" class=\"data row3 col3\" >0.983871</td>\n",
|
|||
|
" <td id=\"T_aeae3_row3_col4\" class=\"data row3 col4\" >0.995673</td>\n",
|
|||
|
" <td id=\"T_aeae3_row3_col5\" class=\"data row3 col5\" >0.986538</td>\n",
|
|||
|
" <td id=\"T_aeae3_row3_col6\" class=\"data row3 col6\" >0.990333</td>\n",
|
|||
|
" <td id=\"T_aeae3_row3_col7\" class=\"data row3 col7\" >0.972112</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_aeae3_level0_row4\" class=\"row_heading level0 row4\" >logistic</th>\n",
|
|||
|
" <td id=\"T_aeae3_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row4_col2\" class=\"data row4 col2\" >0.943601</td>\n",
|
|||
|
" <td id=\"T_aeae3_row4_col3\" class=\"data row4 col3\" >0.830645</td>\n",
|
|||
|
" <td id=\"T_aeae3_row4_col4\" class=\"data row4 col4\" >0.987500</td>\n",
|
|||
|
" <td id=\"T_aeae3_row4_col5\" class=\"data row4 col5\" >0.959615</td>\n",
|
|||
|
" <td id=\"T_aeae3_row4_col6\" class=\"data row4 col6\" >0.970982</td>\n",
|
|||
|
" <td id=\"T_aeae3_row4_col7\" class=\"data row4 col7\" >0.907489</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_aeae3_level0_row5\" class=\"row_heading level0 row5\" >mlp</th>\n",
|
|||
|
" <td id=\"T_aeae3_row5_col0\" class=\"data row5 col0\" >0.990798</td>\n",
|
|||
|
" <td id=\"T_aeae3_row5_col1\" class=\"data row5 col1\" >0.953125</td>\n",
|
|||
|
" <td id=\"T_aeae3_row5_col2\" class=\"data row5 col2\" >0.700651</td>\n",
|
|||
|
" <td id=\"T_aeae3_row5_col3\" class=\"data row5 col3\" >0.491935</td>\n",
|
|||
|
" <td id=\"T_aeae3_row5_col4\" class=\"data row5 col4\" >0.932212</td>\n",
|
|||
|
" <td id=\"T_aeae3_row5_col5\" class=\"data row5 col5\" >0.873077</td>\n",
|
|||
|
" <td id=\"T_aeae3_row5_col6\" class=\"data row5 col6\" >0.820839</td>\n",
|
|||
|
" <td id=\"T_aeae3_row5_col7\" class=\"data row5 col7\" >0.648936</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_aeae3_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
|
|||
|
" <td id=\"T_aeae3_row6_col0\" class=\"data row6 col0\" >0.968531</td>\n",
|
|||
|
" <td id=\"T_aeae3_row6_col1\" class=\"data row6 col1\" >0.907407</td>\n",
|
|||
|
" <td id=\"T_aeae3_row6_col2\" class=\"data row6 col2\" >0.600868</td>\n",
|
|||
|
" <td id=\"T_aeae3_row6_col3\" class=\"data row6 col3\" >0.395161</td>\n",
|
|||
|
" <td id=\"T_aeae3_row6_col4\" class=\"data row6 col4\" >0.907212</td>\n",
|
|||
|
" <td id=\"T_aeae3_row6_col5\" class=\"data row6 col5\" >0.846154</td>\n",
|
|||
|
" <td id=\"T_aeae3_row6_col6\" class=\"data row6 col6\" >0.741633</td>\n",
|
|||
|
" <td id=\"T_aeae3_row6_col7\" class=\"data row6 col7\" >0.550562</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_aeae3_level0_row7\" class=\"row_heading level0 row7\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_aeae3_row7_col0\" class=\"data row7 col0\" >0.369984</td>\n",
|
|||
|
" <td id=\"T_aeae3_row7_col1\" class=\"data row7 col1\" >0.260546</td>\n",
|
|||
|
" <td id=\"T_aeae3_row7_col2\" class=\"data row7 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_aeae3_row7_col3\" class=\"data row7 col3\" >0.846774</td>\n",
|
|||
|
" <td id=\"T_aeae3_row7_col4\" class=\"data row7 col4\" >0.622596</td>\n",
|
|||
|
" <td id=\"T_aeae3_row7_col5\" class=\"data row7 col5\" >0.390385</td>\n",
|
|||
|
" <td id=\"T_aeae3_row7_col6\" class=\"data row7 col6\" >0.540129</td>\n",
|
|||
|
" <td id=\"T_aeae3_row7_col7\" class=\"data row7 col7\" >0.398482</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x29dde68f2c0>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 51,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(\n",
|
|||
|
" by=\"Accuracy_test\", ascending=False\n",
|
|||
|
").style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Многие модели, включая логистическую регрессию, дерево решений,градиентный бустинг и случайный лес, показали выдающиеся результаты по всем метрикам:\n",
|
|||
|
"\n",
|
|||
|
"- Accuracy: Все модели, кроме MLP, достигли идеальной точности (1.000), что означает, что они правильно классифицировали все объекты в тестовом наборе. MLP показала высокую точность 0.996, что все равно является отличным результатом.\n",
|
|||
|
"- F1: Аналогично, все модели показали идеальное значение F1-меры (1.000), кроме MLP, которая показала значение 0.991.\n",
|
|||
|
"- ROC AUC: Все модели достигли максимального значения ROC AUC (1.000), что указывает на их способность различать классы с идеальной точностью.\n",
|
|||
|
"- Cohen's Kappa: Идеальное значение Cohen's Kappa (1.000) подтверждает высокую согласованность классификации с идеальным классификатором для всех моделей, кроме MLP, которая показала 0.989.\n",
|
|||
|
"- MCC: Идеальное значение MCC (1.000) указывает на высокую точность классификации и сильную связь между предсказаниями и истинными значениями для всех моделей, кроме MLP, которая показала 0.989.\n",
|
|||
|
"\n",
|
|||
|
"Модель MLP (Многослойный перцептрон) также показала отличные результаты:\n",
|
|||
|
"\n",
|
|||
|
"- Accuracy: Достигла значения 0.996, что немного ниже идеального, но все еще очень высокий результат.\n",
|
|||
|
"- F1: Значение F1-меры равно 0.991, что также указывает на высокую эффективность модели.\n",
|
|||
|
"- ROC AUC: MLP достигает идеального значения ROC AUC (1.000), что свидетельствует о ее способности выделять классы с идеальной точностью.\n",
|
|||
|
"- Cohen's Kappa: Высокое значение Cohen's Kappa (0.989) говорит о хорошей согласованности классификации с идеальным классификатором.\n",
|
|||
|
"- MCC: Высокое значение MCC (0.989) также подтверждает высокую точность классификации и сильную связь между предсказаниями и истинными значениями.\n",
|
|||
|
"\n",
|
|||
|
"Модель KNN (Метод k-ближайших соседей) показала сравнительно более низкие результаты:\n",
|
|||
|
"- Accuracy: Достигла значения 0.958, что ниже идеального, но все еще является приемлемым результатом.\n",
|
|||
|
"- F1: Значение F1-меры равно 0.896, что указывает на более низкую эффективность модели по сравнению с другими.\n",
|
|||
|
"- ROC AUC: KNN достигает значения ROC AUC 0.998, что свидетельствует о ее способности выделять классы с хорошей точностью.\n",
|
|||
|
"- Cohen's Kappa: Значение Cohen's Kappa (0.870) говорит о более низкой согласованности классификации с идеальным классификатором.\n",
|
|||
|
"- MCC: Значение MCC (0.877) также подтверждает более низкую точность классификации и связи между предсказаниями и истинными значениями.\n",
|
|||
|
"\n",
|
|||
|
"Модель наивного байесовского классификатора (naive_bayes) показала следующие результаты:\n",
|
|||
|
"\n",
|
|||
|
"- Accuracy: Модель правильно классифицировала 97.88% объектов в тестовом наборе. Это довольно хороший результат, но не идеальный.\n",
|
|||
|
"- F1-мера: Значение F1-меры 0.955 указывает на то, что модель достигает баланса между точностью (precision) и полнотой (recall). Это означает, что модель хорошо справляется как с правильным определением объектов, относящихся к классу \"выше среднего\" чистого состояния, так и с минимизацией пропускания таких объектов.\n",
|
|||
|
"- ROC AUC: Модель достигла значения ROC AUC 0.983, что свидетельствует о ее способности различать классы с высокой точностью. \n",
|
|||
|
"- Cohen's Kappa: Значение 0.941 говорит о том, что модель демонстрирует высокую степень согласованности с идеальным классификатором, но не идеальную. \n",
|
|||
|
"- MCC: MCC 0.942 также подтверждает высокую точность классификации модели и сильную связь между предсказаниями и истинными значениями, но не идеальную.\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 52,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_666e5_row0_col0, #T_666e5_row0_col1, #T_666e5_row1_col0, #T_666e5_row1_col1, #T_666e5_row2_col0, #T_666e5_row2_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row0_col2, #T_666e5_row0_col3, #T_666e5_row0_col4, #T_666e5_row1_col2, #T_666e5_row1_col3, #T_666e5_row1_col4, #T_666e5_row2_col2, #T_666e5_row2_col3, #T_666e5_row2_col4, #T_666e5_row3_col2, #T_666e5_row4_col2 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row3_col0 {\n",
|
|||
|
" background-color: #a2da37;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row3_col1 {\n",
|
|||
|
" background-color: #9bd93c;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row3_col3, #T_666e5_row3_col4 {\n",
|
|||
|
" background-color: #d6556d;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row4_col0 {\n",
|
|||
|
" background-color: #95d840;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row4_col1 {\n",
|
|||
|
" background-color: #7cd250;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row4_col3, #T_666e5_row4_col4 {\n",
|
|||
|
" background-color: #cd4a76;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row5_col0 {\n",
|
|||
|
" background-color: #6ece58;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row5_col1 {\n",
|
|||
|
" background-color: #25ac82;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row5_col2 {\n",
|
|||
|
" background-color: #cc4778;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row5_col3 {\n",
|
|||
|
" background-color: #a82296;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row5_col4 {\n",
|
|||
|
" background-color: #ac2694;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row6_col0 {\n",
|
|||
|
" background-color: #63cb5f;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row6_col1 {\n",
|
|||
|
" background-color: #1e9b8a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row6_col2 {\n",
|
|||
|
" background-color: #bf3984;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row6_col3 {\n",
|
|||
|
" background-color: #9814a0;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row6_col4 {\n",
|
|||
|
" background-color: #9e199d;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row7_col0, #T_666e5_row7_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_666e5_row7_col2, #T_666e5_row7_col3, #T_666e5_row7_col4 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_666e5\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_666e5_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_666e5_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_666e5_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_666e5_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_666e5_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_666e5_level0_row0\" class=\"row_heading level0 row0\" >decision_tree</th>\n",
|
|||
|
" <td id=\"T_666e5_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_666e5_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
|
|||
|
" <td id=\"T_666e5_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_666e5_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
|
|||
|
" <td id=\"T_666e5_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_666e5_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_666e5_level0_row3\" class=\"row_heading level0 row3\" >ridge</th>\n",
|
|||
|
" <td id=\"T_666e5_row3_col0\" class=\"data row3 col0\" >0.986538</td>\n",
|
|||
|
" <td id=\"T_666e5_row3_col1\" class=\"data row3 col1\" >0.972112</td>\n",
|
|||
|
" <td id=\"T_666e5_row3_col2\" class=\"data row3 col2\" >0.999471</td>\n",
|
|||
|
" <td id=\"T_666e5_row3_col3\" class=\"data row3 col3\" >0.963241</td>\n",
|
|||
|
" <td id=\"T_666e5_row3_col4\" class=\"data row3 col4\" >0.963361</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_666e5_level0_row4\" class=\"row_heading level0 row4\" >logistic</th>\n",
|
|||
|
" <td id=\"T_666e5_row4_col0\" class=\"data row4 col0\" >0.959615</td>\n",
|
|||
|
" <td id=\"T_666e5_row4_col1\" class=\"data row4 col1\" >0.907489</td>\n",
|
|||
|
" <td id=\"T_666e5_row4_col2\" class=\"data row4 col2\" >0.999430</td>\n",
|
|||
|
" <td id=\"T_666e5_row4_col3\" class=\"data row4 col3\" >0.881941</td>\n",
|
|||
|
" <td id=\"T_666e5_row4_col4\" class=\"data row4 col4\" >0.888152</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_666e5_level0_row5\" class=\"row_heading level0 row5\" >mlp</th>\n",
|
|||
|
" <td id=\"T_666e5_row5_col0\" class=\"data row5 col0\" >0.873077</td>\n",
|
|||
|
" <td id=\"T_666e5_row5_col1\" class=\"data row5 col1\" >0.648936</td>\n",
|
|||
|
" <td id=\"T_666e5_row5_col2\" class=\"data row5 col2\" >0.933447</td>\n",
|
|||
|
" <td id=\"T_666e5_row5_col3\" class=\"data row5 col3\" >0.580891</td>\n",
|
|||
|
" <td id=\"T_666e5_row5_col4\" class=\"data row5 col4\" >0.628281</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_666e5_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
|
|||
|
" <td id=\"T_666e5_row6_col0\" class=\"data row6 col0\" >0.846154</td>\n",
|
|||
|
" <td id=\"T_666e5_row6_col1\" class=\"data row6 col1\" >0.550562</td>\n",
|
|||
|
" <td id=\"T_666e5_row6_col2\" class=\"data row6 col2\" >0.883207</td>\n",
|
|||
|
" <td id=\"T_666e5_row6_col3\" class=\"data row6 col3\" >0.474535</td>\n",
|
|||
|
" <td id=\"T_666e5_row6_col4\" class=\"data row6 col4\" >0.534367</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_666e5_level0_row7\" class=\"row_heading level0 row7\" >naive_bayes</th>\n",
|
|||
|
" <td id=\"T_666e5_row7_col0\" class=\"data row7 col0\" >0.390385</td>\n",
|
|||
|
" <td id=\"T_666e5_row7_col1\" class=\"data row7 col1\" >0.398482</td>\n",
|
|||
|
" <td id=\"T_666e5_row7_col2\" class=\"data row7 col2\" >0.547124</td>\n",
|
|||
|
" <td id=\"T_666e5_row7_col3\" class=\"data row7 col3\" >0.053166</td>\n",
|
|||
|
" <td id=\"T_666e5_row7_col4\" class=\"data row7 col4\" >0.096181</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x29dd9191940>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 52,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"]\n",
|
|||
|
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"- Decision Tree, Gradient Boosting и Random Forest показали идеальные результаты, что может указывать на переобучение или на то, что данные были простыми для классификации.\n",
|
|||
|
"\n",
|
|||
|
"- Ridge и Logistic Regression показали высокие результаты, что указывает на их эффективность в данной задаче классификации.\n",
|
|||
|
"\n",
|
|||
|
"- MLP показала средние результаты, что может указывать на необходимость настройки гиперпараметров или использования более сложной архитектуры.\n",
|
|||
|
"\n",
|
|||
|
"- KNN и Naive Bayes показали низкие результаты, что указывает на их неэффективность в данной задаче классификации."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 53,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'decision_tree'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
|
|||
|
"\n",
|
|||
|
"display(best_model)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Вывод данных с ошибкой предсказания для оценки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 55,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'Error items count: 0'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Преобразование тестовых данных\n",
|
|||
|
"preprocessing_result = pipeline_end.transform(X_test)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Получение предсказаний лучшей модели\n",
|
|||
|
"y_pred = class_models[best_model][\"preds\"]\n",
|
|||
|
"\n",
|
|||
|
"# Нахождение индексов ошибок\n",
|
|||
|
"error_index = y_test[y_test[\"above_average_networth\"] != y_pred].index.tolist() # Изменено на \"above_average_networth\"\n",
|
|||
|
"display(f\"Error items count: {len(error_index)}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создание DataFrame с ошибочными объектами\n",
|
|||
|
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
|
|||
|
"error_df = X_test.loc[error_index].copy()\n",
|
|||
|
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
|
|||
|
"error_df = error_df.sort_index() # Сортировка по индексу"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Пример использования обученной модели (конвейера) для предсказания"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 59,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Rank</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Networth</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Country</th>\n",
|
|||
|
" <th>Source</th>\n",
|
|||
|
" <th>Industry</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1701</th>\n",
|
|||
|
" <td>1645</td>\n",
|
|||
|
" <td>Zugen Ni</td>\n",
|
|||
|
" <td>1.8</td>\n",
|
|||
|
" <td>65</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>appliances</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Rank Name Networth Age Country Source Industry\n",
|
|||
|
"1701 1645 Zugen Ni 1.8 65 China appliances Manufacturing "
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Networth</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Country_Argentina</th>\n",
|
|||
|
" <th>Country_Australia</th>\n",
|
|||
|
" <th>Country_Austria</th>\n",
|
|||
|
" <th>Country_Barbados</th>\n",
|
|||
|
" <th>Country_Belgium</th>\n",
|
|||
|
" <th>Country_Belize</th>\n",
|
|||
|
" <th>Country_Brazil</th>\n",
|
|||
|
" <th>Country_Bulgaria</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>Industry_Manufacturing</th>\n",
|
|||
|
" <th>Industry_Media & Entertainment</th>\n",
|
|||
|
" <th>Industry_Metals & Mining</th>\n",
|
|||
|
" <th>Industry_Real Estate</th>\n",
|
|||
|
" <th>Industry_Service</th>\n",
|
|||
|
" <th>Industry_Sports</th>\n",
|
|||
|
" <th>Industry_Technology</th>\n",
|
|||
|
" <th>Industry_Telecom</th>\n",
|
|||
|
" <th>Industry_diversified</th>\n",
|
|||
|
" <th>Networth_per_Age</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1701</th>\n",
|
|||
|
" <td>-0.279582</td>\n",
|
|||
|
" <td>0.038693</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>-7.22566</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>1 rows × 857 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Networth Age Country_Argentina Country_Australia \\\n",
|
|||
|
"1701 -0.279582 0.038693 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Austria Country_Barbados Country_Belgium Country_Belize \\\n",
|
|||
|
"1701 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Brazil Country_Bulgaria ... Industry_Manufacturing \\\n",
|
|||
|
"1701 0.0 0.0 ... 1.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Media & Entertainment Industry_Metals & Mining \\\n",
|
|||
|
"1701 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Real Estate Industry_Service Industry_Sports \\\n",
|
|||
|
"1701 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Technology Industry_Telecom Industry_diversified \\\n",
|
|||
|
"1701 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Networth_per_Age \n",
|
|||
|
"1701 -7.22566 \n",
|
|||
|
"\n",
|
|||
|
"[1 rows x 857 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'predicted: 0 (proba: [1. 0.])'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'real: 0'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"\n",
|
|||
|
"# Выбираем лучшую модель\n",
|
|||
|
"model = class_models[best_model][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"# Выбираем позицию объекта для анализа\n",
|
|||
|
"example_position = 127\n",
|
|||
|
"\n",
|
|||
|
"# Получаем исходные данные для объекта по позиции\n",
|
|||
|
"test = pd.DataFrame(X_test.iloc[example_position, :]).T\n",
|
|||
|
"display(test)\n",
|
|||
|
"\n",
|
|||
|
"# Получаем преобразованные данные для объекта по позиции\n",
|
|||
|
"test_preprocessed = pd.DataFrame(preprocessed_df.iloc[example_position, :]).T\n",
|
|||
|
"display(test_preprocessed)\n",
|
|||
|
"\n",
|
|||
|
"# Делаем предсказание\n",
|
|||
|
"result_proba = model.predict_proba(test)[0]\n",
|
|||
|
"result = model.predict(test)[0]\n",
|
|||
|
"\n",
|
|||
|
"# Получаем реальное значение\n",
|
|||
|
"real = int(y_test.iloc[example_position].values[0])\n",
|
|||
|
"\n",
|
|||
|
"# Выводим результаты\n",
|
|||
|
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
|
|||
|
"display(f\"real: {real}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Подбор гиперпараметров методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 60,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"{'model__criterion': 'gini',\n",
|
|||
|
" 'model__max_depth': 5,\n",
|
|||
|
" 'model__max_features': 'sqrt',\n",
|
|||
|
" 'model__n_estimators': 50}"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 60,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import GridSearchCV\n",
|
|||
|
"\n",
|
|||
|
"optimized_model_type = \"random_forest\"\n",
|
|||
|
"\n",
|
|||
|
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
|
|||
|
"\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" \"model__n_estimators\": [10, 50, 100],\n",
|
|||
|
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
|
|||
|
" \"model__max_depth\": [5, 7, 10],\n",
|
|||
|
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"gs_optomizer = GridSearchCV(\n",
|
|||
|
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
|
|||
|
")\n",
|
|||
|
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
|
|||
|
"gs_optomizer.best_params_"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Обучение модели с новыми гиперпараметрами__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 61,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
|||
|
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
|||
|
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_model = ensemble.RandomForestClassifier(\n",
|
|||
|
" random_state=random_state,\n",
|
|||
|
" criterion=\"gini\",\n",
|
|||
|
" max_depth=5,\n",
|
|||
|
" max_features=\"log2\",\n",
|
|||
|
" n_estimators=10,\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"result = {}\n",
|
|||
|
"\n",
|
|||
|
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
|
|||
|
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
|
|||
|
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
|
|||
|
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
|
|||
|
"\n",
|
|||
|
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
|
|||
|
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
|
|||
|
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
|
|||
|
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Формирование данных для оценки старой и новой версии модели__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 62,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=class_models[optimized_model_type]\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
|
|||
|
" data=result\n",
|
|||
|
")\n",
|
|||
|
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
|
|||
|
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__Оценка параметров старой и новой модели__"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 63,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_9b5df_row0_col0, #T_9b5df_row0_col1, #T_9b5df_row0_col2, #T_9b5df_row0_col3 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9b5df_row0_col4, #T_9b5df_row0_col5, #T_9b5df_row0_col6, #T_9b5df_row0_col7 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9b5df_row1_col0, #T_9b5df_row1_col1, #T_9b5df_row1_col2, #T_9b5df_row1_col3 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_9b5df_row1_col4, #T_9b5df_row1_col5, #T_9b5df_row1_col6, #T_9b5df_row1_col7 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_9b5df\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_9b5df_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
|
|||
|
" <th id=\"T_9b5df_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
|
|||
|
" <th id=\"T_9b5df_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
|
|||
|
" <th id=\"T_9b5df_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
|
|||
|
" <th id=\"T_9b5df_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
|
|||
|
" <th id=\"T_9b5df_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_9b5df_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
|
|||
|
" <th id=\"T_9b5df_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" <th class=\"blank col5\" > </th>\n",
|
|||
|
" <th class=\"blank col6\" > </th>\n",
|
|||
|
" <th class=\"blank col7\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_9b5df_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_9b5df_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_9b5df_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_9b5df_row1_col0\" class=\"data row1 col0\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row1_col1\" class=\"data row1 col1\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row1_col2\" class=\"data row1 col2\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row1_col3\" class=\"data row1 col3\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row1_col4\" class=\"data row1 col4\" >0.778365</td>\n",
|
|||
|
" <td id=\"T_9b5df_row1_col5\" class=\"data row1 col5\" >0.761538</td>\n",
|
|||
|
" <td id=\"T_9b5df_row1_col6\" class=\"data row1 col6\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_9b5df_row1_col7\" class=\"data row1 col7\" >0.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x29ddbb6fb60>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 63,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" \"Accuracy_train\",\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_train\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"].style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Precision_train\",\n",
|
|||
|
" \"Precision_test\",\n",
|
|||
|
" \"Recall_train\",\n",
|
|||
|
" \"Recall_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Как для обучающей (Precision_train), так и для тестовой (Precision_test) выборки обе модели достигли идеальных значений 1.000000. Это указывает на то, что модели очень точно классифицируют положительные образцы, не пропуская их."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 64,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style type=\"text/css\">\n",
|
|||
|
"#T_d0354_row0_col0, #T_d0354_row0_col1 {\n",
|
|||
|
" background-color: #a8db34;\n",
|
|||
|
" color: #000000;\n",
|
|||
|
"}\n",
|
|||
|
"#T_d0354_row0_col2, #T_d0354_row0_col3, #T_d0354_row0_col4 {\n",
|
|||
|
" background-color: #da5a6a;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_d0354_row1_col0, #T_d0354_row1_col1 {\n",
|
|||
|
" background-color: #26818e;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"#T_d0354_row1_col2, #T_d0354_row1_col3, #T_d0354_row1_col4 {\n",
|
|||
|
" background-color: #4e02a2;\n",
|
|||
|
" color: #f1f1f1;\n",
|
|||
|
"}\n",
|
|||
|
"</style>\n",
|
|||
|
"<table id=\"T_d0354\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"blank level0\" > </th>\n",
|
|||
|
" <th id=\"T_d0354_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
|
|||
|
" <th id=\"T_d0354_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
|
|||
|
" <th id=\"T_d0354_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
|
|||
|
" <th id=\"T_d0354_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
|
|||
|
" <th id=\"T_d0354_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th class=\"index_name level0\" >Name</th>\n",
|
|||
|
" <th class=\"blank col0\" > </th>\n",
|
|||
|
" <th class=\"blank col1\" > </th>\n",
|
|||
|
" <th class=\"blank col2\" > </th>\n",
|
|||
|
" <th class=\"blank col3\" > </th>\n",
|
|||
|
" <th class=\"blank col4\" > </th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_d0354_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
|
|||
|
" <td id=\"T_d0354_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_d0354_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_d0354_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_d0354_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
|
|||
|
" <td id=\"T_d0354_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th id=\"T_d0354_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
|
|||
|
" <td id=\"T_d0354_row1_col0\" class=\"data row1 col0\" >0.761538</td>\n",
|
|||
|
" <td id=\"T_d0354_row1_col1\" class=\"data row1 col1\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_d0354_row1_col2\" class=\"data row1 col2\" >0.999572</td>\n",
|
|||
|
" <td id=\"T_d0354_row1_col3\" class=\"data row1 col3\" >0.000000</td>\n",
|
|||
|
" <td id=\"T_d0354_row1_col4\" class=\"data row1 col4\" >0.000000</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"<pandas.io.formats.style.Styler at 0x29ddbb6dd90>"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 64,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"optimized_metrics[\n",
|
|||
|
" [\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" ]\n",
|
|||
|
"].style.background_gradient(\n",
|
|||
|
" cmap=\"plasma\",\n",
|
|||
|
" low=0.3,\n",
|
|||
|
" high=1,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"ROC_AUC_test\",\n",
|
|||
|
" \"MCC_test\",\n",
|
|||
|
" \"Cohen_kappa_test\",\n",
|
|||
|
" ],\n",
|
|||
|
").background_gradient(\n",
|
|||
|
" cmap=\"viridis\",\n",
|
|||
|
" low=1,\n",
|
|||
|
" high=0.3,\n",
|
|||
|
" subset=[\n",
|
|||
|
" \"Accuracy_test\",\n",
|
|||
|
" \"F1_test\",\n",
|
|||
|
" ],\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Оба варианта модели продемонстрировали безупречную точность классификации, достигнув значения 1.000000. Это свидетельствует о том, что модели точно классифицировали все тестовые примеры, не допустив никаких ошибок в предсказаниях."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 67,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6cAAAGsCAYAAAAhRNGaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABqgUlEQVR4nO3dd3QU1cPG8WcT0iCNUBJCr0mQJkWMUqUEBKQpLQgo4isiKAoiKB2MothAQSRSNCg2VEBBQErAKEgXECUioBBASkIoaTvvH/kxulJMyCa7G76fc+bozszeuZOEffbOvXPHYhiGIQAAAAAAHMjN0RUAAAAAAIDGKQAAAADA4WicAgAAAAAcjsYpAAAAAMDhaJwCAAAAAByOxikAAAAAwOFonAIAAAAAHI7GKQAAAADA4Yo4ugIAAOTWpUuXlJ6ebrfyPD095e3tbbfyAADIDXItG41TAIBLuXTpkipX9FXSiSy7lRkSEqKDBw+6ZJADAFwbufY3GqcAAJeSnp6upBNZOri1ovz98n53Sso5qyo3OKT09HSXC3EAgOsj1/5G4xQA4JL8/dzsEuIAADgDco3GKQDARWUZVmUZ9ikHAABHI9donAIAXJRVhqzKe4rbowwAAPKKXONRMgAAAAAAJ0DPKQDAJVlllT0GLtmnFAAA8oZco3EKAHBRWYahLCPvQ5fsUQYAAHlFrjGsFwAAAADgBOg5BQC4JCaOAAAUJuQajVMAgIuyylDWTR7iAIDCg1xjWC8AAAAAwAnQcwoAcEkMfwIAFCbkGj2nAAAAAAAnQM8pAMAlMeU+AKAwIddonAIAXJT1f4s9ygEAwNHINYb1AgAAAACcAD2nAACXlGWnKfftUQYAAHlFrtE4BQC4qCwje7FHOQAAOBq5xrBeAAAAAIAToOcUAOCSmDgCAFCYkGs0TgEALsoqi7JksUs5AAA4GrnGsF4AAAAAgBOg5xQA4JKsRvZij3IAAHA0co2eUwAAAACAE6DnFADgkrLsdG+OPcoAACCvyDUapwAAF0WIAwAKE3KNYb0AAAAAACdAzykAwCVZDYushh2m3LdDGQAA5BW5RuMUAOCiGP4EAChMyDWG9QIAAAAAnAA9pwAAl5QlN2XZ4Rprlh3qAgBAXpFrNE4BAC7KsNO9OYYL35sDACg8yDWG9QIAAAAAnAA9pwAAl8TEEQCAwoRco3EKAHBRWYabsgw73Jtj2KEyAADkEbnGsF4AAAAAgBOg5xQA4JKssshqh2usVrnwJWYAQKFBrtFzCgAAAABwAvScAgBcEhNHAAAKE3KNxikAwEXZb+II1x3+BAAoPMg1hvUCAAAAAJwAjVMAgEvKnjjCPktuzJo1S3Xq1JG/v7/8/f0VGRmpr7/+2tzeokULWSwWm+WRRx6xKePw4cPq0KGDihYtqtKlS2vkyJHKzMy0y88FAOCayDWG9QLIJ/Pnz9cDDzyggwcPqlKlStfdt1KlSmrRooXmz59fIHVD4WCVm7IcMKthuXLl9MILL6h69eoyDEMLFixQ586dtX37dt1yyy2SpEGDBmnSpEnme4oWLWr+f1ZWljp06KCQkBB99913OnbsmPr16ycPDw89//zzeT4fAIBrItfoOQWQS3v27FHfvn1VtmxZeXl5KTQ0VNHR0dqzZ4+jqwYUiE6dOunuu+9W9erVVaNGDU2dOlW+vr76/vvvzX2KFi2qkJAQc/H39ze3ffPNN9q7d6/ef/991atXT+3bt9fkyZP15ptvKj093RGnBNwU5s+fL4vFIm9vb/35559XbG/RooVq1arlgJoBjuVMuUbjFECOffbZZ6pfv77WrFmjBx54QG+99ZYGDhyotWvXqn79+lqyZImjq4ibyOWJI+yxSFJKSorNkpaW9t91yMrShx9+qPPnzysyMtJcHxcXp5IlS6pWrVoaPXq0Lly4YG5LSEhQ7dq1FRwcbK6LiopSSkoKF3mAApCWlqYXXnjB0dUArkCuMawXQA4lJibq/vvvV5UqVbRhwwaVKlXK3Pb444+radOmuv/++7Vr1y5VqVLFgTXFzcIqN7s+rLx8+fI268ePH68JEyZc9T27d+9WZGSkLl26JF9fXy1ZskQ1a9aUJPXp00cVK1ZUaGiodu3apVGjRmn//v367LPPJElJSUk2AS7JfJ2UlJTn8wFwffXq1dM777yj0aNHKzQ01NHVAUzkGj2nAHLopZde0oULFzRnzhybhqkklSxZUm+//bbOnz+vadOmXbMMwzA0ZcoUlStXTkWLFlXLli3pKYLTOHLkiJKTk81l9OjR19w3LCxMO3bs0A8//KDBgwerf//+2rt3ryTp4YcfVlRUlGrXrq3o6GgtXLhQS5YsUWJiYkGdCoDrGDNmjLKysnLUe/r++++rQYMG8vHxUVBQkHr16qUjR46Y29944w25u7vr7Nmz5rrp06fLYrHoySefNNdlZWXJz89Po0aNsuu5ANfjirlG4xRAjixdulSVKlVS06ZNr7q9WbNmqlSpkpYvX37NMsaNG6exY8eqbt26eumll1SlShW1bdtW58+fz69qoxDLMix2WySZsxReXry8vK55bE9PT1WrVk0NGjRQTEyM6tatq9dff/2q+zZu3FiSdODAAUlSSEiIjh8/brPP5dchISF5/rkAuL7KlSurX79+euedd3T06NFr7jd16lT169dP1atX1yuvvKInnnhCa9asUbNmzczGaNOmTWW1WrVx40bzffHx8XJzc1N8fLy5bvv27UpNTVWzZs3y7bzg+sg1GqcAciA5OVlHjx5V3bp1r7tfnTp19Mcff+jcuXNXbDt58qSmTZumDh06aNmyZRoyZIhiY2M1YMAA/fXXX/lVdaBAWK3Wa97Ls2PHDklSmTJlJEmRkZHavXu3Tpw4Ye6zatUq+fv7m0OoAOSvZ599VpmZmXrxxRevuv3QoUMaP368pkyZog8//FCDBw/WuHHjtHbtWv3xxx966623JEl169aVv7+/2RA1DEMbN25U9+7dzQap9HeD9c477yyYEwTyyFG5RuMUwH+63Nj08/O77n6Xt6ekpFyxbfXq1UpPT9fQoUNlsfz9/K0nnnjCfhXFTSXrf1Pu22PJjdGjR2vDhg36/ffftXv3bo0ePVrr1q1TdHS0EhMTNXnyZG3dulW///67vvzyS/Xr10/NmjVTnTp1JElt27ZVzZo1df/992vnzp1auXKlnnvuOQ0ZMuS6V7UB2E+VKlV0//33a86cOTp27NgV2z/77DNZrVb16NFDf/31l7mEhISoevXqWrt2rSTJzc1Nd9xxhzZs2CBJ2rdvn06dOqVnnnlGhmEoISFBUnbjtFatWgoMDCywc4TrIddonALIgcuNzqv1iP7T9Rqxhw4dkiRVr17dZn2pUqVUvHhxe1QTNxmr4Wa3JTdOnDihfv36KSwsTK1atdKWLVu0cuVKtWnTRp6enlq9erXatm2r8PBwPfXUU+revbuWLl1qvt/d3V3Lli2Tu7u7IiMj1bdvX/Xr18/m+XEA8t9zzz2nzMzMq957+uuvv8owDFWvXl2lSpWyWfbt22fTQ9S0aVNt3bpVFy9eVHx8vMqUKaP69eurbt26Zo/qxo0br3lbDHAZucZsvQByICAgQGXKlNGuXbuuu9+uXbtUtmxZm2dfAYVNbGzsNbeVL19e69ev/88yKlasqK+++sqe1QKQS1WqVFHfvn01Z84cPfPMMzbbrFarLBaLvv76a7m7u1/xXl9fX/P/mzRpooyMDCUkJCg+Pt5shDZt2lTx8fH6+eefdfLkSRqncFrOlGs0TgHkSMeOHfXOO+9o48aNatKkyRXb4+Pj9fvvv+v//u//rvr+ihUrSsq+Gv3PR82cPHlSZ86cyZ9Ko1C7kaFLVy/HsENtALii5557Tu+///4V955WrVpVhmGocuXKqlGjxnXLuO222+Tp6an4+HjFx8dr5MiRkrInCnznnXe0Zs0a8zVwPeQaw3oB5NDIkSPl4+Oj//u//9OpU6dstp0+fVqPPPKIihYtaobyv7Vu3VoeHh6aMWOGDOPvD83XXnstP6uNQswq+8xsaHX0iQBwmKpVq6p
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x400 with 4 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"for index in range(0, len(optimized_metrics)):\n",
|
|||
|
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
|
|||
|
" disp = ConfusionMatrixDisplay(\n",
|
|||
|
" confusion_matrix=c_matrix, display_labels=[\"Below Average\", \"Above Average\"] \n",
|
|||
|
" ).plot(ax=ax.flat[index])\n",
|
|||
|
" disp.ax_.set_title(optimized_metrics.index[index])\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
|||
|
"plt.show()\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"В желтом квадрате мы видим значение 396, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"Below Average\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
|
|||
|
"\n",
|
|||
|
"В зеленом квадрате значение 124 указывает на количество правильно классифицированных объектов, отнесенных к классу \"Above Average\". Это также является показателем высокой точности модели в определении объектов данного класса."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Определение достижимого уровня качества модели для второй задачи (задача регрессии)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"__2. Прогнозирование цены закрытия акций:__\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"Описание: Оценить, какая будет цена закрытия акций Starbucks на следующий день или через несколько дней на основе исторических данных.\n",
|
|||
|
"Целевая переменная: Цена закрытия (Close). (среднее значение)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Загрузка данных и создание целевой переменной"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 68,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Среднее значение поля 'Networth': 4.8607499999999995\n",
|
|||
|
" Rank Name Networth Age Country \\\n",
|
|||
|
"0 1 Elon Musk 219.0 50 United States \n",
|
|||
|
"1 2 Jeff Bezos 171.0 58 United States \n",
|
|||
|
"2 3 Bernard Arnault & family 158.0 73 France \n",
|
|||
|
"3 4 Bill Gates 129.0 66 United States \n",
|
|||
|
"4 5 Warren Buffett 118.0 91 United States \n",
|
|||
|
"\n",
|
|||
|
" Source Industry above_average_networth \n",
|
|||
|
"0 Tesla, SpaceX Automotive 1 \n",
|
|||
|
"1 Amazon Technology 1 \n",
|
|||
|
"2 LVMH Fashion & Retail 1 \n",
|
|||
|
"3 Microsoft Technology 1 \n",
|
|||
|
"4 Berkshire Hathaway Finance & Investments 1 \n",
|
|||
|
"Статистическое описание DataFrame:\n",
|
|||
|
" Rank Networth Age above_average_networth\n",
|
|||
|
"count 2600.000000 2600.000000 2600.000000 2600.000000\n",
|
|||
|
"mean 1269.570769 4.860750 64.271923 0.225000\n",
|
|||
|
"std 728.146364 10.659671 13.220607 0.417663\n",
|
|||
|
"min 1.000000 1.000000 19.000000 0.000000\n",
|
|||
|
"25% 637.000000 1.500000 55.000000 0.000000\n",
|
|||
|
"50% 1292.000000 2.400000 64.000000 0.000000\n",
|
|||
|
"75% 1929.000000 4.500000 74.000000 0.000000\n",
|
|||
|
"max 2578.000000 219.000000 100.000000 1.000000\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"from sklearn import set_config\n",
|
|||
|
"\n",
|
|||
|
"set_config(transform_output=\"pandas\")\n",
|
|||
|
"\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//Forbes Billionaires.csv\")\n",
|
|||
|
"\n",
|
|||
|
"# Опция для настройки генерации случайных чисел (если это нужно для других частей кода)\n",
|
|||
|
"random_state = 42\n",
|
|||
|
"\n",
|
|||
|
"# Вычисление среднего значения поля \"Networth\"\n",
|
|||
|
"average_networth = df['Networth'].mean()\n",
|
|||
|
"print(f\"Среднее значение поля 'Networth': {average_networth}\")\n",
|
|||
|
"\n",
|
|||
|
"# Создание новой колонки, указывающей, выше или ниже среднего значение чистого состояния\n",
|
|||
|
"df['above_average_networth'] = (df['Networth'] > average_networth).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод DataFrame с новой колонкой\n",
|
|||
|
"print(df.head())\n",
|
|||
|
"\n",
|
|||
|
"# Примерный анализ данных\n",
|
|||
|
"print(\"Статистическое описание DataFrame:\")\n",
|
|||
|
"print(df.describe())\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии\n",
|
|||
|
"\n",
|
|||
|
"Целевой признак -- above_average_networth"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 69,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Rank</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Networth</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Country</th>\n",
|
|||
|
" <th>Source</th>\n",
|
|||
|
" <th>Industry</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>582</th>\n",
|
|||
|
" <td>579</td>\n",
|
|||
|
" <td>Alexandra Schoerghuber & family</td>\n",
|
|||
|
" <td>4.9</td>\n",
|
|||
|
" <td>63</td>\n",
|
|||
|
" <td>Germany</td>\n",
|
|||
|
" <td>real estate</td>\n",
|
|||
|
" <td>Real Estate</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>48</th>\n",
|
|||
|
" <td>49</td>\n",
|
|||
|
" <td>He Xiangjian</td>\n",
|
|||
|
" <td>28.3</td>\n",
|
|||
|
" <td>79</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>home appliances</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1772</th>\n",
|
|||
|
" <td>1729</td>\n",
|
|||
|
" <td>Bruce Mathieson</td>\n",
|
|||
|
" <td>1.7</td>\n",
|
|||
|
" <td>78</td>\n",
|
|||
|
" <td>Australia</td>\n",
|
|||
|
" <td>hotels</td>\n",
|
|||
|
" <td>Food & Beverage</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>964</th>\n",
|
|||
|
" <td>951</td>\n",
|
|||
|
" <td>Pansy Ho</td>\n",
|
|||
|
" <td>3.2</td>\n",
|
|||
|
" <td>59</td>\n",
|
|||
|
" <td>Hong Kong</td>\n",
|
|||
|
" <td>casinos</td>\n",
|
|||
|
" <td>Gambling & Casinos</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2213</th>\n",
|
|||
|
" <td>2190</td>\n",
|
|||
|
" <td>Sasson Dayan & family</td>\n",
|
|||
|
" <td>1.3</td>\n",
|
|||
|
" <td>82</td>\n",
|
|||
|
" <td>Brazil</td>\n",
|
|||
|
" <td>banking</td>\n",
|
|||
|
" <td>Finance & Investments</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1638</th>\n",
|
|||
|
" <td>1579</td>\n",
|
|||
|
" <td>Wang Chou-hsiong</td>\n",
|
|||
|
" <td>1.9</td>\n",
|
|||
|
" <td>81</td>\n",
|
|||
|
" <td>Taiwan</td>\n",
|
|||
|
" <td>footwear</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1095</th>\n",
|
|||
|
" <td>1096</td>\n",
|
|||
|
" <td>Jose Joao Abdalla Filho</td>\n",
|
|||
|
" <td>2.8</td>\n",
|
|||
|
" <td>76</td>\n",
|
|||
|
" <td>Brazil</td>\n",
|
|||
|
" <td>investments</td>\n",
|
|||
|
" <td>Finance & Investments</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1130</th>\n",
|
|||
|
" <td>1096</td>\n",
|
|||
|
" <td>Lin Chen-hai</td>\n",
|
|||
|
" <td>2.8</td>\n",
|
|||
|
" <td>75</td>\n",
|
|||
|
" <td>Taiwan</td>\n",
|
|||
|
" <td>real estate</td>\n",
|
|||
|
" <td>Real Estate</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1294</th>\n",
|
|||
|
" <td>1292</td>\n",
|
|||
|
" <td>Banwari Lal Bawri</td>\n",
|
|||
|
" <td>2.4</td>\n",
|
|||
|
" <td>69</td>\n",
|
|||
|
" <td>India</td>\n",
|
|||
|
" <td>pharmaceuticals</td>\n",
|
|||
|
" <td>Healthcare</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>860</th>\n",
|
|||
|
" <td>851</td>\n",
|
|||
|
" <td>Kuok Khoon Hong</td>\n",
|
|||
|
" <td>3.5</td>\n",
|
|||
|
" <td>72</td>\n",
|
|||
|
" <td>Singapore</td>\n",
|
|||
|
" <td>palm oil</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>2080 rows × 7 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Rank Name Networth Age Country \\\n",
|
|||
|
"582 579 Alexandra Schoerghuber & family 4.9 63 Germany \n",
|
|||
|
"48 49 He Xiangjian 28.3 79 China \n",
|
|||
|
"1772 1729 Bruce Mathieson 1.7 78 Australia \n",
|
|||
|
"964 951 Pansy Ho 3.2 59 Hong Kong \n",
|
|||
|
"2213 2190 Sasson Dayan & family 1.3 82 Brazil \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"1638 1579 Wang Chou-hsiong 1.9 81 Taiwan \n",
|
|||
|
"1095 1096 Jose Joao Abdalla Filho 2.8 76 Brazil \n",
|
|||
|
"1130 1096 Lin Chen-hai 2.8 75 Taiwan \n",
|
|||
|
"1294 1292 Banwari Lal Bawri 2.4 69 India \n",
|
|||
|
"860 851 Kuok Khoon Hong 3.5 72 Singapore \n",
|
|||
|
"\n",
|
|||
|
" Source Industry \n",
|
|||
|
"582 real estate Real Estate \n",
|
|||
|
"48 home appliances Manufacturing \n",
|
|||
|
"1772 hotels Food & Beverage \n",
|
|||
|
"964 casinos Gambling & Casinos \n",
|
|||
|
"2213 banking Finance & Investments \n",
|
|||
|
"... ... ... \n",
|
|||
|
"1638 footwear Manufacturing \n",
|
|||
|
"1095 investments Finance & Investments \n",
|
|||
|
"1130 real estate Real Estate \n",
|
|||
|
"1294 pharmaceuticals Healthcare \n",
|
|||
|
"860 palm oil Manufacturing \n",
|
|||
|
"\n",
|
|||
|
"[2080 rows x 7 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_train'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>above_average_networth</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>582</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>48</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1772</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>964</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2213</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1638</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1095</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1130</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1294</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>860</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>2080 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" above_average_networth\n",
|
|||
|
"582 1\n",
|
|||
|
"48 1\n",
|
|||
|
"1772 0\n",
|
|||
|
"964 0\n",
|
|||
|
"2213 0\n",
|
|||
|
"... ...\n",
|
|||
|
"1638 0\n",
|
|||
|
"1095 0\n",
|
|||
|
"1130 0\n",
|
|||
|
"1294 0\n",
|
|||
|
"860 0\n",
|
|||
|
"\n",
|
|||
|
"[2080 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'X_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Rank</th>\n",
|
|||
|
" <th>Name</th>\n",
|
|||
|
" <th>Networth</th>\n",
|
|||
|
" <th>Age</th>\n",
|
|||
|
" <th>Country</th>\n",
|
|||
|
" <th>Source</th>\n",
|
|||
|
" <th>Industry</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1593</th>\n",
|
|||
|
" <td>1579</td>\n",
|
|||
|
" <td>Guangming Fu & family</td>\n",
|
|||
|
" <td>1.9</td>\n",
|
|||
|
" <td>68</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>poultry</td>\n",
|
|||
|
" <td>Food & Beverage</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>196</th>\n",
|
|||
|
" <td>197</td>\n",
|
|||
|
" <td>Leon Black</td>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>70</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>private equity</td>\n",
|
|||
|
" <td>Finance & Investments</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>239</th>\n",
|
|||
|
" <td>235</td>\n",
|
|||
|
" <td>Zong Qinghou</td>\n",
|
|||
|
" <td>8.8</td>\n",
|
|||
|
" <td>76</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>beverages</td>\n",
|
|||
|
" <td>Food & Beverage</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2126</th>\n",
|
|||
|
" <td>2076</td>\n",
|
|||
|
" <td>Kurt Krieger</td>\n",
|
|||
|
" <td>1.4</td>\n",
|
|||
|
" <td>74</td>\n",
|
|||
|
" <td>Germany</td>\n",
|
|||
|
" <td>furniture retailing</td>\n",
|
|||
|
" <td>Fashion & Retail</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1587</th>\n",
|
|||
|
" <td>1579</td>\n",
|
|||
|
" <td>Chen Kaichen</td>\n",
|
|||
|
" <td>1.9</td>\n",
|
|||
|
" <td>64</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>household chemicals</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1778</th>\n",
|
|||
|
" <td>1729</td>\n",
|
|||
|
" <td>Jorge Perez</td>\n",
|
|||
|
" <td>1.7</td>\n",
|
|||
|
" <td>72</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>real estate</td>\n",
|
|||
|
" <td>Real Estate</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>166</th>\n",
|
|||
|
" <td>167</td>\n",
|
|||
|
" <td>Brian Chesky</td>\n",
|
|||
|
" <td>11.5</td>\n",
|
|||
|
" <td>40</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>Airbnb</td>\n",
|
|||
|
" <td>Technology</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>949</th>\n",
|
|||
|
" <td>913</td>\n",
|
|||
|
" <td>Zhong Ruonong & family</td>\n",
|
|||
|
" <td>3.3</td>\n",
|
|||
|
" <td>59</td>\n",
|
|||
|
" <td>China</td>\n",
|
|||
|
" <td>electronics</td>\n",
|
|||
|
" <td>Manufacturing</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>49</th>\n",
|
|||
|
" <td>50</td>\n",
|
|||
|
" <td>Miriam Adelson</td>\n",
|
|||
|
" <td>27.5</td>\n",
|
|||
|
" <td>76</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>casinos</td>\n",
|
|||
|
" <td>Gambling & Casinos</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2511</th>\n",
|
|||
|
" <td>2448</td>\n",
|
|||
|
" <td>Lou Boliang</td>\n",
|
|||
|
" <td>1.1</td>\n",
|
|||
|
" <td>58</td>\n",
|
|||
|
" <td>United States</td>\n",
|
|||
|
" <td>pharmaceuticals</td>\n",
|
|||
|
" <td>Healthcare</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>520 rows × 7 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Rank Name Networth Age Country \\\n",
|
|||
|
"1593 1579 Guangming Fu & family 1.9 68 China \n",
|
|||
|
"196 197 Leon Black 10.0 70 United States \n",
|
|||
|
"239 235 Zong Qinghou 8.8 76 China \n",
|
|||
|
"2126 2076 Kurt Krieger 1.4 74 Germany \n",
|
|||
|
"1587 1579 Chen Kaichen 1.9 64 China \n",
|
|||
|
"... ... ... ... ... ... \n",
|
|||
|
"1778 1729 Jorge Perez 1.7 72 United States \n",
|
|||
|
"166 167 Brian Chesky 11.5 40 United States \n",
|
|||
|
"949 913 Zhong Ruonong & family 3.3 59 China \n",
|
|||
|
"49 50 Miriam Adelson 27.5 76 United States \n",
|
|||
|
"2511 2448 Lou Boliang 1.1 58 United States \n",
|
|||
|
"\n",
|
|||
|
" Source Industry \n",
|
|||
|
"1593 poultry Food & Beverage \n",
|
|||
|
"196 private equity Finance & Investments \n",
|
|||
|
"239 beverages Food & Beverage \n",
|
|||
|
"2126 furniture retailing Fashion & Retail \n",
|
|||
|
"1587 household chemicals Manufacturing \n",
|
|||
|
"... ... ... \n",
|
|||
|
"1778 real estate Real Estate \n",
|
|||
|
"166 Airbnb Technology \n",
|
|||
|
"949 electronics Manufacturing \n",
|
|||
|
"49 casinos Gambling & Casinos \n",
|
|||
|
"2511 pharmaceuticals Healthcare \n",
|
|||
|
"\n",
|
|||
|
"[520 rows x 7 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"'y_test'"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>above_average_networth</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1593</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>196</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>239</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2126</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1587</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1778</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>166</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>949</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>49</th>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2511</th>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>520 rows × 1 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" above_average_networth\n",
|
|||
|
"1593 0\n",
|
|||
|
"196 1\n",
|
|||
|
"239 1\n",
|
|||
|
"2126 0\n",
|
|||
|
"1587 0\n",
|
|||
|
"... ...\n",
|
|||
|
"1778 0\n",
|
|||
|
"166 1\n",
|
|||
|
"949 0\n",
|
|||
|
"49 1\n",
|
|||
|
"2511 0\n",
|
|||
|
"\n",
|
|||
|
"[520 rows x 1 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from typing import Tuple\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from pandas import DataFrame\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"def split_into_train_test(\n",
|
|||
|
" df_input: DataFrame,\n",
|
|||
|
" target_colname: str = \"above_average_networth\", \n",
|
|||
|
" frac_train: float = 0.8,\n",
|
|||
|
" random_state: int = None,\n",
|
|||
|
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
|
|||
|
" \n",
|
|||
|
" if not (0 < frac_train < 1):\n",
|
|||
|
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
|
|||
|
" \n",
|
|||
|
" # Проверка наличия целевого признака\n",
|
|||
|
" if target_colname not in df_input.columns:\n",
|
|||
|
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
|
|||
|
" \n",
|
|||
|
" # Разделяем данные на признаки и целевую переменную\n",
|
|||
|
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
|
|||
|
" y = df_input[[target_colname]] # Целевая переменная\n",
|
|||
|
"\n",
|
|||
|
" # Разделяем данные на обучающую и тестовую выборки\n",
|
|||
|
" X_train, X_test, y_train, y_test = train_test_split(\n",
|
|||
|
" X, y,\n",
|
|||
|
" test_size=(1.0 - frac_train),\n",
|
|||
|
" random_state=random_state\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" return X_train, X_test, y_train, y_test\n",
|
|||
|
"\n",
|
|||
|
"# Применение функции для разделения данных\n",
|
|||
|
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
|
|||
|
" df, \n",
|
|||
|
" target_colname=\"above_average_networth\", \n",
|
|||
|
" frac_train=0.8, \n",
|
|||
|
" random_state=42 \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Для отображения результатов\n",
|
|||
|
"display(\"X_train\", X_train)\n",
|
|||
|
"display(\"y_train\", y_train)\n",
|
|||
|
"\n",
|
|||
|
"display(\"X_test\", X_test)\n",
|
|||
|
"display(\"y_test\", y_test)\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование конвейера для решения задачи регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 70,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" Networth Age Country_Argentina Country_Australia \\\n",
|
|||
|
"0 20.092595 -1.079729 0.0 0.0 \n",
|
|||
|
"1 15.588775 -0.474496 0.0 0.0 \n",
|
|||
|
"2 14.368991 0.660314 0.0 0.0 \n",
|
|||
|
"3 11.647933 0.130736 0.0 0.0 \n",
|
|||
|
"4 10.615808 2.022087 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"2595 -0.362253 1.189893 0.0 0.0 \n",
|
|||
|
"2596 -0.362253 1.341201 0.0 0.0 \n",
|
|||
|
"2597 -0.362253 0.509006 0.0 0.0 \n",
|
|||
|
"2598 -0.362253 0.282044 0.0 0.0 \n",
|
|||
|
"2599 -0.362253 0.357698 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Austria Country_Barbados Country_Belgium Country_Belize \\\n",
|
|||
|
"0 0.0 0.0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"2595 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2596 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2597 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2598 0.0 0.0 0.0 0.0 \n",
|
|||
|
"2599 0.0 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Country_Brazil Country_Bulgaria ... Industry_Manufacturing \\\n",
|
|||
|
"0 0.0 0.0 ... 0.0 \n",
|
|||
|
"1 0.0 0.0 ... 0.0 \n",
|
|||
|
"2 0.0 0.0 ... 0.0 \n",
|
|||
|
"3 0.0 0.0 ... 0.0 \n",
|
|||
|
"4 0.0 0.0 ... 0.0 \n",
|
|||
|
"... ... ... ... ... \n",
|
|||
|
"2595 0.0 0.0 ... 0.0 \n",
|
|||
|
"2596 0.0 0.0 ... 0.0 \n",
|
|||
|
"2597 0.0 0.0 ... 0.0 \n",
|
|||
|
"2598 0.0 0.0 ... 0.0 \n",
|
|||
|
"2599 0.0 0.0 ... 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Media & Entertainment Industry_Metals & Mining \\\n",
|
|||
|
"0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 \n",
|
|||
|
"... ... ... \n",
|
|||
|
"2595 0.0 0.0 \n",
|
|||
|
"2596 0.0 0.0 \n",
|
|||
|
"2597 0.0 0.0 \n",
|
|||
|
"2598 0.0 0.0 \n",
|
|||
|
"2599 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Real Estate Industry_Service Industry_Sports \\\n",
|
|||
|
"0 0.0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"2595 0.0 0.0 0.0 \n",
|
|||
|
"2596 0.0 0.0 0.0 \n",
|
|||
|
"2597 0.0 0.0 0.0 \n",
|
|||
|
"2598 0.0 0.0 0.0 \n",
|
|||
|
"2599 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Industry_Technology Industry_Telecom Industry_diversified \\\n",
|
|||
|
"0 0.0 0.0 0.0 \n",
|
|||
|
"1 1.0 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 0.0 \n",
|
|||
|
"3 1.0 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 0.0 \n",
|
|||
|
"... ... ... ... \n",
|
|||
|
"2595 0.0 0.0 0.0 \n",
|
|||
|
"2596 0.0 0.0 0.0 \n",
|
|||
|
"2597 0.0 0.0 0.0 \n",
|
|||
|
"2598 0.0 0.0 0.0 \n",
|
|||
|
"2599 0.0 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Networth_per_Age \n",
|
|||
|
"0 -18.608929 \n",
|
|||
|
"1 -32.853309 \n",
|
|||
|
"2 21.760834 \n",
|
|||
|
"3 89.095063 \n",
|
|||
|
"4 5.249926 \n",
|
|||
|
"... ... \n",
|
|||
|
"2595 -0.304441 \n",
|
|||
|
"2596 -0.270096 \n",
|
|||
|
"2597 -0.711686 \n",
|
|||
|
"2598 -1.284383 \n",
|
|||
|
"2599 -1.012732 \n",
|
|||
|
"\n",
|
|||
|
"[2600 rows x 988 columns]\n",
|
|||
|
"(2600, 988)\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn.base import BaseEstimator, TransformerMixin\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor \n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.pipeline import make_pipeline\n",
|
|||
|
"\n",
|
|||
|
"class ForbesBillionairesFeatures(BaseEstimator, TransformerMixin): \n",
|
|||
|
" def __init__(self):\n",
|
|||
|
" pass\n",
|
|||
|
"\n",
|
|||
|
" def fit(self, X, y=None):\n",
|
|||
|
" return self\n",
|
|||
|
"\n",
|
|||
|
" def transform(self, X, y=None):\n",
|
|||
|
" X[\"Networth_per_Age\"] = X[\"Networth\"] / X[\"Age\"]\n",
|
|||
|
" return X\n",
|
|||
|
"\n",
|
|||
|
" def get_feature_names_out(self, features_in):\n",
|
|||
|
" return np.append(features_in, [\"Networth_per_Age\"], axis=0) \n",
|
|||
|
"\n",
|
|||
|
"# Определите признаки для вашей задачи\n",
|
|||
|
"columns_to_drop = [\"Rank \", \"Name\"] \n",
|
|||
|
"num_columns = [\"Networth\", \"Age\"] \n",
|
|||
|
"cat_columns = [\"Country\", \"Source\", \"Industry\"]\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование числовых признаков\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование категориальных признаков\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Формирование конвейера\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\" \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"drop_columns = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\",\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Окончательный конвейер\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" (\"custom_features\", ForbesBillionairesFeatures()), # Добавляем custom_features\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"df = pd.read_csv(\"..//static//csv//Forbes Billionaires.csv\")\n",
|
|||
|
"\n",
|
|||
|
"# Создаем целевой признак\n",
|
|||
|
"average_networth = df['Networth'].mean()\n",
|
|||
|
"df['above_average_networth'] = (df['Networth'] > average_networth).astype(int)\n",
|
|||
|
"\n",
|
|||
|
"# Подготовка данных\n",
|
|||
|
"X = df.drop('above_average_networth', axis=1)\n",
|
|||
|
"y = df['above_average_networth'].values.ravel()\n",
|
|||
|
"\n",
|
|||
|
"# Применение конвейера\n",
|
|||
|
"X_processed = pipeline_end.fit_transform(X)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод\n",
|
|||
|
"print(X_processed)\n",
|
|||
|
"print(X_processed.shape)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Формирование набора моделей для регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 71,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
|
|||
|
" return fit_method(estimator, *args, **kwargs)\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True)\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True)\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True)\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True)\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
|
|||
|
" y = column_or_1d(y, warn=True)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Random Forest: Mean Score = 0.9999449765688064, Standard Deviation = 0.00010860474979394001\n",
|
|||
|
"Linear Regression: Mean Score = -5.286122247142867e+21, Standard Deviation = 9.978968848315854e+21\n",
|
|||
|
"Gradient Boosting: Mean Score = 0.9999999992916644, Standard Deviation = 2.7301021406313204e-12\n",
|
|||
|
"Support Vector Regression: Mean Score = 0.6826855358064324, Standard Deviation = 0.020395315184745886\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.linear_model import LinearRegression\n",
|
|||
|
"from sklearn.ensemble import GradientBoostingRegressor\n",
|
|||
|
"from sklearn.svm import SVR\n",
|
|||
|
"from sklearn.model_selection import cross_val_score\n",
|
|||
|
"\n",
|
|||
|
"def train_multiple_models(X, y, models):\n",
|
|||
|
" results = {}\n",
|
|||
|
" for model_name, model in models.items():\n",
|
|||
|
" # Создаем конвейер для каждой модели\n",
|
|||
|
" model_pipeline = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" (\"drop_columns\", drop_columns),\n",
|
|||
|
" (\"model\", model) # Используем текущую модель\n",
|
|||
|
" ]\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" # Обучаем модель и вычисляем кросс-валидацию\n",
|
|||
|
" scores = cross_val_score(model_pipeline, X, y, cv=5) # 5-кратная кросс-валидация\n",
|
|||
|
" results[model_name] = {\n",
|
|||
|
" \"mean_score\": scores.mean(),\n",
|
|||
|
" \"std_dev\": scores.std()\n",
|
|||
|
" }\n",
|
|||
|
" \n",
|
|||
|
" return results\n",
|
|||
|
"\n",
|
|||
|
"models = {\n",
|
|||
|
" \"Random Forest\": RandomForestRegressor(),\n",
|
|||
|
" \"Linear Regression\": LinearRegression(),\n",
|
|||
|
" \"Gradient Boosting\": GradientBoostingRegressor(),\n",
|
|||
|
" \"Support Vector Regression\": SVR()\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"results = train_multiple_models(X_train, y_train, models)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод результатов\n",
|
|||
|
"for model_name, scores in results.items():\n",
|
|||
|
" print(f\"{model_name}: Mean Score = {scores['mean_score']}, Standard Deviation = {scores['std_dev']}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"- Random Forest показала очень высокое среднее значение, близкое к 1, что указывает на ее высокую точность в предсказании. Стандартное отклонение также очень низкое, что говорит о стабильности модели.\n",
|
|||
|
"- Линейная регрессия показала очень низкое среднее значение с огромным отрицательным числом, что указывает на ее неэффективность в данной задаче. Стандартное отклонение также очень высокое, что говорит о нестабильности модели.\n",
|
|||
|
"- Gradient Boosting показала практически идеальное среднее значение, близкое к 1, что указывает на ее высокую точность в предсказании. Стандартное отклонение практически равно нулю, что говорит о чрезвычайной стабильности модели.\n",
|
|||
|
"- Support Vector Regression показала среднее значение около 0.68, что указывает на ее умеренную точность в предсказании. Стандартное отклонение относительно низкое, что говорит о стабильности модели, но она все же уступает Random Forest и Gradient Boosting."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение моделей на обучающем наборе данных и оценка на тестовом для регрессии"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 72,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: logistic\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.0125\n",
|
|||
|
"MSE (test): 0.04038461538461539\n",
|
|||
|
"MAE (train): 0.0125\n",
|
|||
|
"MAE (test): 0.04038461538461539\n",
|
|||
|
"R2 (train): 0.9275415718173158\n",
|
|||
|
"R2 (test): 0.7776148582600195\n",
|
|||
|
"STD (train): 0.11110243021644485\n",
|
|||
|
"STD (test): 0.19685959012669935\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: ridge\n",
|
|||
|
"MSE (train): 0.004326923076923077\n",
|
|||
|
"MSE (test): 0.013461538461538462\n",
|
|||
|
"MAE (train): 0.004326923076923077\n",
|
|||
|
"MAE (test): 0.013461538461538462\n",
|
|||
|
"R2 (train): 0.9749182363983017\n",
|
|||
|
"R2 (test): 0.9258716194200065\n",
|
|||
|
"STD (train): 0.0656368860749005\n",
|
|||
|
"STD (test): 0.11588034534756023\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: decision_tree\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.0\n",
|
|||
|
"MSE (test): 0.0\n",
|
|||
|
"MAE (train): 0.0\n",
|
|||
|
"MAE (test): 0.0\n",
|
|||
|
"R2 (train): 1.0\n",
|
|||
|
"R2 (test): 1.0\n",
|
|||
|
"STD (train): 0.0\n",
|
|||
|
"STD (test): 0.0\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: knn\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.09278846153846154\n",
|
|||
|
"MSE (test): 0.15384615384615385\n",
|
|||
|
"MAE (train): 0.09278846153846154\n",
|
|||
|
"MAE (test): 0.15384615384615385\n",
|
|||
|
"R2 (train): 0.4621355138746903\n",
|
|||
|
"R2 (test): 0.1528185076572175\n",
|
|||
|
"STD (train): 0.29276240884468824\n",
|
|||
|
"STD (test): 0.3684085396282311\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: naive_bayes\n",
|
|||
|
"MSE (train): 0.37740384615384615\n",
|
|||
|
"MSE (test): 0.6096153846153847\n",
|
|||
|
"MAE (train): 0.37740384615384615\n",
|
|||
|
"MAE (test): 0.6096153846153847\n",
|
|||
|
"R2 (train): -1.1876871585925808\n",
|
|||
|
"R2 (test): -2.3569566634082757\n",
|
|||
|
"STD (train): 0.4847372309428379\n",
|
|||
|
"STD (test): 0.5672229402142737\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: gradient_boosting\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.0\n",
|
|||
|
"MSE (test): 0.0\n",
|
|||
|
"MAE (train): 0.0\n",
|
|||
|
"MAE (test): 0.0\n",
|
|||
|
"R2 (train): 1.0\n",
|
|||
|
"R2 (test): 1.0\n",
|
|||
|
"STD (train): 0.0\n",
|
|||
|
"STD (test): 0.0\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: random_forest\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"MSE (train): 0.0\n",
|
|||
|
"MSE (test): 0.0\n",
|
|||
|
"MAE (train): 0.0\n",
|
|||
|
"MAE (test): 0.0\n",
|
|||
|
"R2 (train): 1.0\n",
|
|||
|
"R2 (test): 1.0\n",
|
|||
|
"STD (train): 0.0\n",
|
|||
|
"STD (test): 0.0\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Model: mlp\n",
|
|||
|
"MSE (train): 0.06778846153846153\n",
|
|||
|
"MSE (test): 0.12692307692307692\n",
|
|||
|
"MAE (train): 0.06778846153846153\n",
|
|||
|
"MAE (test): 0.12692307692307692\n",
|
|||
|
"R2 (train): 0.6070523702400588\n",
|
|||
|
"R2 (test): 0.30107526881720437\n",
|
|||
|
"STD (train): 0.2521427220700598\n",
|
|||
|
"STD (test): 0.3370600353877945\n",
|
|||
|
"----------------------------------------\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"\n",
|
|||
|
"# Проверка наличия необходимых переменных\n",
|
|||
|
"if 'class_models' not in locals():\n",
|
|||
|
" raise ValueError(\"class_models is not defined\")\n",
|
|||
|
"if 'X_train' not in locals() or 'X_test' not in locals() or 'y_train' not in locals() or 'y_test' not in locals():\n",
|
|||
|
" raise ValueError(\"Train/test data is not defined\")\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"y_train = np.ravel(y_train) \n",
|
|||
|
"y_test = np.ravel(y_test) \n",
|
|||
|
"\n",
|
|||
|
"# Инициализация списка для хранения результатов\n",
|
|||
|
"results = []\n",
|
|||
|
"\n",
|
|||
|
"# Проход по моделям и оценка их качества\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
" \n",
|
|||
|
" # Извлечение модели из словаря\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
" \n",
|
|||
|
" # Создание пайплайна\n",
|
|||
|
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
|||
|
" \n",
|
|||
|
" # Обучение модели\n",
|
|||
|
" model_pipeline.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
" # Предсказание для обучающей и тестовой выборки\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_predict = model_pipeline.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
" # Сохранение пайплайна и предсказаний\n",
|
|||
|
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" # Вычисление метрик для регрессии\n",
|
|||
|
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
|
|||
|
"\n",
|
|||
|
" # Дополнительные метрики\n",
|
|||
|
" class_models[model_name][\"STD_train\"] = np.std(y_train - y_train_predict)\n",
|
|||
|
" class_models[model_name][\"STD_test\"] = np.std(y_test - y_test_predict)\n",
|
|||
|
"\n",
|
|||
|
" # Вывод результатов для текущей модели\n",
|
|||
|
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
|
|||
|
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
|
|||
|
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
|
|||
|
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
|
|||
|
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
|
|||
|
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
|
|||
|
" print(f\"STD (train): {class_models[model_name]['STD_train']}\")\n",
|
|||
|
" print(f\"STD (test): {class_models[model_name]['STD_test']}\")\n",
|
|||
|
" print(\"-\" * 40) # Разделитель для разных моделей"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Пример использования обученной модели (конвейера регрессии) для предсказания"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 74,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model: RandomForest\n",
|
|||
|
"MSE (train): 24.028673442957558\n",
|
|||
|
"MSE (test): 68.96006650623248\n",
|
|||
|
"MAE (train): 1.548185999451937\n",
|
|||
|
"MAE (test): 3.372747412240537\n",
|
|||
|
"R2 (train): 0.8231149198653249\n",
|
|||
|
"R2 (test): -1.9013866015383956\n",
|
|||
|
"----------------------------------------\n",
|
|||
|
"Прогнозируемое чистое состояние: 1.3689999999999998\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n",
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1, 2] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor \n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"from sklearn.preprocessing import OneHotEncoder\n",
|
|||
|
"\n",
|
|||
|
"# 1. Загрузка данных\n",
|
|||
|
"data = pd.read_csv(\"..//static//csv//Forbes Billionaires.csv\") \n",
|
|||
|
"\n",
|
|||
|
"# 2. Подготовка данных для прогноза\n",
|
|||
|
"average_networth = data['Networth'].mean()\n",
|
|||
|
"data['above_average_networth'] = (data['Networth'] > average_networth).astype(int) \n",
|
|||
|
"\n",
|
|||
|
"# Предикторы и целевая переменная\n",
|
|||
|
"X = data.drop('Networth', axis=1) \n",
|
|||
|
"y = data['Networth']\n",
|
|||
|
"\n",
|
|||
|
"# 3. Инициализация модели и пайплайна\n",
|
|||
|
"class_models = {\n",
|
|||
|
" \"RandomForest\": {\n",
|
|||
|
" \"model\": RandomForestRegressor(n_estimators=100, random_state=42),\n",
|
|||
|
" }\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Предобработка признаков\n",
|
|||
|
"num_columns = ['Age']\n",
|
|||
|
"cat_columns = ['Country', 'Source', 'Industry']\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование числовых признаков\n",
|
|||
|
"num_transformer = Pipeline(steps=[\n",
|
|||
|
" ('imputer', SimpleImputer(strategy='median')),\n",
|
|||
|
" ('scaler', StandardScaler())\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Преобразование категориальных признаков\n",
|
|||
|
"cat_transformer = Pipeline(steps=[\n",
|
|||
|
" ('imputer', SimpleImputer(strategy='constant', fill_value='unknown')),\n",
|
|||
|
" ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False, drop=\"first\"))\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Создание конвейера предобработки\n",
|
|||
|
"preprocessor = ColumnTransformer(\n",
|
|||
|
" transformers=[\n",
|
|||
|
" ('num', num_transformer, num_columns),\n",
|
|||
|
" ('cat', cat_transformer, cat_columns)\n",
|
|||
|
" ])\n",
|
|||
|
"\n",
|
|||
|
"# Создание конвейера модели\n",
|
|||
|
"pipeline_end = Pipeline(steps=[\n",
|
|||
|
" ('preprocessor', preprocessor),\n",
|
|||
|
" # ('model', model) # Модель добавляется в цикле\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"results = []\n",
|
|||
|
"\n",
|
|||
|
"# 4. Обучение модели и оценка\n",
|
|||
|
"for model_name in class_models.keys():\n",
|
|||
|
" print(f\"Model: {model_name}\")\n",
|
|||
|
"\n",
|
|||
|
" model = class_models[model_name][\"model\"]\n",
|
|||
|
" model_pipeline = Pipeline(steps=[\n",
|
|||
|
" ('preprocessor', preprocessor),\n",
|
|||
|
" ('model', model)\n",
|
|||
|
" ])\n",
|
|||
|
"\n",
|
|||
|
" # Разделение данных\n",
|
|||
|
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
|
|||
|
"\n",
|
|||
|
" # Обучение модели\n",
|
|||
|
" model_pipeline.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
" # Предсказание\n",
|
|||
|
" y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
" y_test_predict = model_pipeline.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
" # Сохранение результатов\n",
|
|||
|
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
|||
|
"\n",
|
|||
|
" # Вычисление метрик\n",
|
|||
|
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
|
|||
|
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
|
|||
|
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
|
|||
|
"\n",
|
|||
|
" # Вывод результатов\n",
|
|||
|
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
|
|||
|
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
|
|||
|
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
|
|||
|
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
|
|||
|
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
|
|||
|
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
|
|||
|
" print(\"-\" * 40)\n",
|
|||
|
"\n",
|
|||
|
"# Прогнозирование чистого состояния для нового миллиардера\n",
|
|||
|
"new_billionaire_data = pd.DataFrame({\n",
|
|||
|
" 'Age': [50],\n",
|
|||
|
" 'Country': ['USA'],\n",
|
|||
|
" 'Source': ['Self Made'], \n",
|
|||
|
" 'Industry': ['Technology'], \n",
|
|||
|
"})\n",
|
|||
|
"\n",
|
|||
|
"predicted_networth = model_pipeline.predict(new_billionaire_data)\n",
|
|||
|
"print(f\"Прогнозируемое чистое состояние: {predicted_networth[0]}\")\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Подбор гиперпараметров методом поиска по сетке"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 84,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
|
|||
|
"Лучшие параметры: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 100}\n",
|
|||
|
"Лучший результат (MSE): 5.88542132388105\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.pipeline import Pipeline\n",
|
|||
|
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"\n",
|
|||
|
"# Удаление строк с пропущенными значениями (если необходимо)\n",
|
|||
|
"df = df.dropna()\n",
|
|||
|
"\n",
|
|||
|
"# Создание целевой переменной (Networth)\n",
|
|||
|
"target = df['Networth']\n",
|
|||
|
"\n",
|
|||
|
"# Удаление целевой переменной из исходных данных\n",
|
|||
|
"features = df.drop(columns=['Networth'])\n",
|
|||
|
"\n",
|
|||
|
"# Удаление столбцов, которые не будут использоваться (например, имена)\n",
|
|||
|
"features = features.drop(columns=['Name'])\n",
|
|||
|
"\n",
|
|||
|
"# Определение столбцов для обработки\n",
|
|||
|
"num_columns = features.select_dtypes(include=['number']).columns\n",
|
|||
|
"cat_columns = features.select_dtypes(include=['object']).columns\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг числовых столбцов\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\") # Используем медиану для заполнения пропущенных значений в числовых столбцах\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг категориальных столбцов\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\") # Используем 'unknown' для заполнения пропущенных значений в категориальных столбцах\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Объединение препроцессинга\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание финального пайплайна\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Разделение данных на обучающую и тестовую выборки\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)\n",
|
|||
|
"\n",
|
|||
|
"# Применение пайплайна к данным\n",
|
|||
|
"X_train_processed = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"X_test_processed = pipeline_end.transform(X_test)\n",
|
|||
|
"\n",
|
|||
|
"# 2. Создание и настройка модели случайного леса\n",
|
|||
|
"model = RandomForestRegressor()\n",
|
|||
|
"\n",
|
|||
|
"# Установка параметров для поиска по сетке\n",
|
|||
|
"param_grid = {\n",
|
|||
|
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
|
|||
|
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
|
|||
|
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# 3. Подбор гиперпараметров с помощью Grid Search\n",
|
|||
|
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели на тренировочных данных\n",
|
|||
|
"grid_search.fit(X_train_processed, y_train)\n",
|
|||
|
"\n",
|
|||
|
"# 4. Результаты подбора гиперпараметров\n",
|
|||
|
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
|
|||
|
"print(\"Лучший результат (MSE):\", -grid_search.best_score_) # Меняем знак, так как берем отрицательное значение среднеквадратичной ошибки"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Обучение модели с новыми гиперпараметрами и сравнение новых и старых данных"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 87,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"c:\\Users\\Admin\\Desktop\\5 semestr\\mii\\AIM-PIbd-32-Safiulova-K-N\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
|||
|
" warnings.warn(\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
|
|||
|
"Старые параметры: {'max_depth': 20, 'min_samples_split': 2, 'n_estimators': 50}\n",
|
|||
|
"Лучший результат (MSE) на старых параметрах: 5.760387482085847\n",
|
|||
|
"\n",
|
|||
|
"Новые параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 200}\n",
|
|||
|
"Лучший результат (MSE) на новых параметрах: 13.643983185514095\n",
|
|||
|
"Среднеквадратическая ошибка (MSE) на тестовых данных: 0.024952019817877404\n",
|
|||
|
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 0.15796208348169316\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAHWCAYAAACi1sL/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABj9klEQVR4nO3dd3gU1fv38c+GhPRGCRBK6DWIFAuggKIERKqCEJAqKiJFioAKISBGRIoKoqAColQFVCx8ld4ldGlSQpFeE5KQBJJ5/uDJ/nZJArthUwjv13XtRebMmTP3DJNM7pwzZ0yGYRgCAAAAAEiSnHI6AAAAAADITUiSAAAAAMACSRIAAAAAWCBJAgAAAAALJEkAAAAAYIEkCQAAAAAskCQBAAAAgAWSJAAAAACwQJIEAAAAABZIkgAAAADAAkkScJ84cuSIXnvtNZUtW1Zubm7y8fFR/fr19cknn+j69es5Hd4DY/Xq1TKZTDKZTPruu+/SrVO/fn2ZTCYFBwdblSclJemTTz5RzZo15ePjIz8/P1WrVk2vvvqqDhw4YK43a9Ys8z7S+2zevDlLjxEAgAedc04HAODufv31V7Vr106urq7q0qWLgoODlZSUpPXr12vIkCHau3evpk+fntNhPlDc3Nw0d+5cde7c2ar82LFj2rhxo9zc3NJs88ILL+j3339Xx44d1atXL924cUMHDhzQsmXLVK9ePVWuXNmq/ujRo1WmTJk07ZQvX96xBwMAAKyQJAG5XFRUlDp06KCgoCCtXLlSxYoVM6/r06ePDh8+rF9//TUHI3wwPffcc/r555918eJFFSpUyFw+d+5cFSlSRBUqVNCVK1fM5Vu3btWyZcs0duxYvfPOO1ZtTZkyRVevXk2zj2bNmqlOnTpZdgwAACB9DLcDcrmPPvpIsbGx+vrrr60SpFTly5dX//79zcsmk0lvvvmmvv/+e1WqVElubm6qXbu21q5da7Xd8ePH9cYbb6hSpUpyd3dXwYIF1a5dOx07dsyq3u1Dvzw8PFS9enV99dVXVvW6desmLy+vNPH98MMPMplMWr16tVX5li1b1LRpU/n6+srDw0MNGzbUhg0brOqMGjVKJpNJFy9etCqPjIyUyWTSrFmzrPZfunRpq3onT56Uu7u7TCZTmuP6/fff9eSTT8rT01Pe3t5q3ry59u7dmyb+jLRq1Uqurq5atGiRVfncuXPVvn175cuXz6r8yJEjkm4Nxbtdvnz5VLBgQZv3bYtjx45lOFzv9nMhSY0aNUq3ruU5lqRp06YpODhYHh4eVvV++OGHu8Z06tQp9ezZU4GBgXJ1dVWZMmXUu3dvJSUl3XWIoWUsu3fvVrdu3cxDT4sWLaoePXro0qVLVvtLvX4OHDig9u3by8fHRwULFlT//v2VkJBgVTf1+yYjqfGlnruVK1fKyclJI0eOtKo3d+5cmUwmTZs27Y7nolGjRmrUqJFV2datW83HejeNGjVKM5xTkj7++ON0/48///xzVatWTa6urgoMDFSfPn3SJOa3XwOFChVS8+bN9c8//1jVy4lzdafrwvJYf/rpJzVv3tx8jZUrV05jxoxRcnJymjaDg4O1bds21atXT+7u7ipTpoy++OILq3pJSUkaOXKkateuLV9fX3l6eurJJ5/UqlWrrOpZfr8tXbrUal1CQoL8/f1lMpn08ccfW607deqUevTooSJFisjV1VXVqlXTN998Y15vObw3o8+oUaMk2Xe937x5U2PGjFG5cuXk6uqq0qVL65133lFiYqJVvdKlS5v34+TkpKJFi+qll17SiRMn7vh/BuQF9CQBudwvv/yismXLql69ejZvs2bNGi1YsED9+vWTq6urPv/8czVt2lR///23+RerrVu3auPGjerQoYNKlCihY8eOadq0aWrUqJH27dsnDw8PqzYnTZqkQoUKKSYmRt9884169eql0qVL65lnnrH7mFauXKlmzZqpdu3aCgsLk5OTk2bOnKmnn35a69at06OPPmp3m+kZOXJkml8OJGnOnDnq2rWrQkJCNG7cOMXHx2vatGl64okntGPHjjTJVno8PDzUqlUrzZs3T71795Yk7dq1S3v37tVXX32l3bt3W9UPCgqSJH3//feqX7++nJ3v/uM3Ojo6TYJoMpnsSqg6duyo5557TpL022+/ad68eRnWrVy5st59911J0sWLF/XWW29ZrV+wYIHeeOMNNWrUSH379pWnp6f279+vDz744K5xnD59Wo8++qiuXr2qV199VZUrV9apU6f0ww8/KD4+Xg0aNNCcOXPM9ceOHStJ5ngkmb8H/vzzTx09elTdu3dX0aJFzcNN9+7dq82bN6dJMtq3b6/SpUsrIiJCmzdv1qeffqorV67o22+/vWvcGXn66af1xhtvKCIiQq1bt1atWrV05swZ9e3bV88884xef/11u9scOnRopuO5k1GjRik8PFzPPPOMevfurYMHD2ratGnaunWrNmzYIBcXF3Pd1GvAMAwdOXJEEydO1HPPPXdPvxQ74lyVKFFCERERVmXpXc+zZs2Sl5eXBg4cKC8vL61cuVIjR45UTEyMxo8fb1X3ypUreu6559S+fXt17NhRCxcuVO/evZU/f3716NFDkhQTE6OvvvrKPET22rVr+vrrrxUSEqK///5bDz/8sFWbbm5umjlzplq3bm0uW7x4cbo/h86dO6fHH3/cnHQWLlxYv//+u3r27KmYmBgNGDBAVapUsfq+mD59uvbv369JkyaZyx566CGrdm253l955RXNnj1bL774ogYNGqQtW7YoIiJC+/fv15IlS6zae/LJJ/Xqq68qJSVF//zzjyZPnqzTp09r3bp1aY4JyFMMALlWdHS0Iclo1aqVzdtIMiQZkZGR5rLjx48bbm5uRps2bcxl8fHxabbdtGmTIcn49ttvzWUzZ840JBlRUVHmsn///deQZHz00Ufmsq5duxqenp5p2ly0aJEhyVi1apVhGIaRkpJiVKhQwQgJCTFSUlKs4ilTpozx7LPPmsvCwsIMScaFCxes2ty6dashyZg5c6bV/oOCgszL//zzj+Hk5GQ0a9bMKv5r164Zfn5+Rq9evazaPHv2rOHr65um/HarVq0yJBmLFi0yli1bZphMJuPEiROGYRjGkCFDjLJlyxqGYRgNGzY0qlWrZt4uJSXFaNiwoSHJKFKkiNGxY0dj6tSpxvHjx9PsI/Wcp/dxdXW9Y3ypUv+PPv74Y3PZ+PHj0/xfpqpfv77x1FNPmZejoqLSnOOOHTsafn5+xvXr19M9H3fSpUsXw8nJydi6dWuadZbXQaqGDRsaDRs2TLet9K7defPmGZKMtWvXmstSr5+WLVta1X3jjTcMScauXbvMZZKMPn36ZBh/et8HcXFxRvny5Y1q1aoZCQkJRvPmzQ0fH590/0/vdny//fabIclo2rSpYcut+fbrK9Xt/8fnz5838ufPbzRp0sRITk4215syZYohyfjmm28yjMkwDOOdd94xJBnnz583l+XEubLlWA0j/WvjtddeMzw8PIyEhASrNiUZEyZMMJclJiYaDz/8sBEQEGAkJSUZhmEYN2/eNBITE63au3LlilGkSBGjR48e5rLU75eOHTsazs7OxtmzZ83rGjdubISGhhqSjPHjx5vLe/bsaRQrVsy4ePGiVfsdOnQwfH190z2W23/OWbL1et+5c6chyXjllVes6g0ePNiQZKxcudJcFhQUZHTt2tWqXmhoqOHh4ZFuDEBewnA7IBeLiYmRJHl7e9u1Xd26dVW7dm3zcqlSpdSqVSstX77cPOzE3d3dvP7GjRu6dOmSypcvLz8/P23fvj1Nm1euXNHFixd19OhRTZo0Sfny5VPDhg3T1Lt48aLV59q1a1brd+7cqUOHDik0NFSXLl0y14uLi1Pjxo21du1apaSkWG1z+fJlqzajo6Pveg6GDx+uWrVqqV27dlblf/75p65evaqOHTtatZkvXz499thjaYbR3EmTJk1UoEABzZ8/X4ZhaP78+erYsWO6dU0mk5YvX673339f/v7+mjdvnvr06aOgoCC99NJL6T6TNHXqVP35559Wn99
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1000x500 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|||
|
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder\n",
|
|||
|
"from sklearn.compose import ColumnTransformer\n",
|
|||
|
"from sklearn.impute import SimpleImputer\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# Определение столбцов для обработки\n",
|
|||
|
"num_columns = features.select_dtypes(include=['number']).columns\n",
|
|||
|
"cat_columns = features.select_dtypes(include=['object']).columns\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг числовых столбцов\n",
|
|||
|
"num_imputer = SimpleImputer(strategy=\"median\") # Используем медиану для заполнения пропущенных значений в числовых столбцах\n",
|
|||
|
"num_scaler = StandardScaler()\n",
|
|||
|
"preprocessing_num = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", num_imputer),\n",
|
|||
|
" (\"scaler\", num_scaler),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Препроцессинг категориальных столбцов\n",
|
|||
|
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\") # Используем 'unknown' для заполнения пропущенных значений в категориальных столбцах\n",
|
|||
|
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
|||
|
"preprocessing_cat = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"imputer\", cat_imputer),\n",
|
|||
|
" (\"encoder\", cat_encoder),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Объединение препроцессинга\n",
|
|||
|
"features_preprocessing = ColumnTransformer(\n",
|
|||
|
" verbose_feature_names_out=False,\n",
|
|||
|
" transformers=[\n",
|
|||
|
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
|
|||
|
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
|
|||
|
" ],\n",
|
|||
|
" remainder=\"passthrough\"\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание финального пайплайна\n",
|
|||
|
"pipeline_end = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"features_preprocessing\", features_preprocessing),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Разделение данных на обучающую и тестовую выборки\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)\n",
|
|||
|
"\n",
|
|||
|
"# Применение пайплайна к данным\n",
|
|||
|
"X_train_processed = pipeline_end.fit_transform(X_train)\n",
|
|||
|
"X_test_processed = pipeline_end.transform(X_test)\n",
|
|||
|
"\n",
|
|||
|
"# 1. Настройка параметров для старых значений\n",
|
|||
|
"old_param_grid = {\n",
|
|||
|
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
|
|||
|
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
|
|||
|
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Подбор гиперпараметров с помощью Grid Search для старых параметров\n",
|
|||
|
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
|
|||
|
" param_grid=old_param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели на тренировочных данных\n",
|
|||
|
"old_grid_search.fit(X_train_processed, y_train)\n",
|
|||
|
"\n",
|
|||
|
"# 2. Результаты подбора для старых параметров\n",
|
|||
|
"old_best_params = old_grid_search.best_params_\n",
|
|||
|
"old_best_mse = -old_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
|
|||
|
"\n",
|
|||
|
"# 3. Настройка параметров для новых значений\n",
|
|||
|
"new_param_grid = {\n",
|
|||
|
" 'n_estimators': [200],\n",
|
|||
|
" 'max_depth': [10],\n",
|
|||
|
" 'min_samples_split': [10]\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Подбор гиперпараметров с помощью Grid Search для новых параметров\n",
|
|||
|
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
|
|||
|
" param_grid=new_param_grid,\n",
|
|||
|
" scoring='neg_mean_squared_error', cv=2)\n",
|
|||
|
"\n",
|
|||
|
"# Обучение модели на тренировочных данных\n",
|
|||
|
"new_grid_search.fit(X_train_processed, y_train)\n",
|
|||
|
"\n",
|
|||
|
"# 4. Результаты подбора для новых параметров\n",
|
|||
|
"new_best_params = new_grid_search.best_params_\n",
|
|||
|
"new_best_mse = -new_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
|
|||
|
"\n",
|
|||
|
"# 5. Обучение модели с лучшими параметрами для новых значений\n",
|
|||
|
"model_best = RandomForestRegressor(**new_best_params)\n",
|
|||
|
"model_best.fit(X_train_processed, y_train)\n",
|
|||
|
"\n",
|
|||
|
"# Прогнозирование на тестовой выборке\n",
|
|||
|
"y_pred = model_best.predict(X_test_processed)\n",
|
|||
|
"\n",
|
|||
|
"# Оценка производительности модели\n",
|
|||
|
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
|
|||
|
"rmse = np.sqrt(mse)\n",
|
|||
|
"\n",
|
|||
|
"# Вывод результатов\n",
|
|||
|
"print(\"Старые параметры:\", old_best_params)\n",
|
|||
|
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
|
|||
|
"print(\"\\nНовые параметры:\", new_best_params)\n",
|
|||
|
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
|
|||
|
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
|
|||
|
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)\n",
|
|||
|
"\n",
|
|||
|
"# Визуализация ошибок\n",
|
|||
|
"plt.figure(figsize=(10, 5))\n",
|
|||
|
"plt.bar(['Старые параметры', 'Новые параметры'], [old_best_mse, new_best_mse], color=['blue', 'orange'])\n",
|
|||
|
"plt.xlabel('Подбор параметров')\n",
|
|||
|
"plt.ylabel('Среднеквадратическая ошибка (MSE)')\n",
|
|||
|
"plt.title('Сравнение MSE для старых и новых параметров')\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Модель, обученная на новых параметрах, показала худший результат (MSE) на кросс-валидации, что указывает на ее меньшую точность по сравнению с моделью, обученной на старых параметрах. Однако, MSE на тестовых данных одинакова для обеих моделей, что говорит о том, что обе модели имеют одинаковую производительность на тестовых данных."
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "aimenv",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.12.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|