2024-11-16 21:46:44 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Датасет: [Tesla Insider Trading](https://www.kaggle.com/datasets/ilyaryabov/tesla-insider-trading).\n",
"\n",
"### Описание датасета:\n",
"Датасет представляет собой выборку операций с ценными бумагами компании Tesla, совершённых инсайдерами, и является частью более крупного проекта \"Insider Trading S&P500 – Inside Info\". Данные охватывают транзакции с участием крупных акционеров и должностных лиц компании, включая такие операции, как покупка, продажа и опционы, начиная с 10 ноября 2021 года и до 27 июля 2022 года.\n",
"\n",
"---\n",
"\n",
"### Анализ сведений:\n",
"**Проблемная область:**\n",
"Проблемная область данного датасета касается анализа инсайдерских сделок в публичных компаниях, а также их влияния на ценообразование акций. Инсайдерские транзакции, совершаемые людьми с доступом к непубличной информации (такими как руководители, крупные акционеры или члены совета директоров), могут быть индикаторами будущих изменений стоимости акций. Исследование таких транзакций помогает понять, как информация внутри компании отражается в действиях ключевых участников, и может выявить паттерны поведения, которые влияют на рынки.\n",
"\n",
"**Актуальность:**\n",
"Анализ инсайдерских сделок становится особенно важным в условиях высокой волатильности рынка и неопределенности. Инвесторы, аналитики и компании используют такие данные, чтобы лучше понимать сигналы от крупных акционеров и должностных лиц. Действия инсайдеров, такие как покупки и продажи акций, нередко рассматриваются как индикаторы доверия к компании, что может оказывать значительное влияние на рыночные ожидания и прогнозы.\n",
"\n",
"**Объекты наблюдений:**\n",
"Объектами наблюдений в датасете являются инсайдеры компании Tesla — лица, имеющие значительное влияние на управление и информацию компании. Каждый объект характеризуется различными параметрами, включая должность, тип транзакции, количество акций и общую стоимость сделок.\n",
"\n",
"**Атрибуты объектов:**\n",
"- Insider Trading: ФИО лица, совершившего транзакцию.\n",
"- Relationship: Должность или статус данного лица в компании Tesla.\n",
"- Date: Дата завершения транзакции.\n",
"- Transaction: Тип транзакции.\n",
"- Cost: Цена одной акции на момент совершения транзакции.\n",
"- Shares: Количество акций, участвующих в транзакции.\n",
"- Value ($): Общая стоимость транзакции в долларах США.\n",
"- Shares Total: Общее количество акций, принадлежащих этому лицу после завершения данной транзакции.\n",
"- SEC Form 4: Дата записи транзакции в форме SEC Form 4, обязательной для отчётности о сделках инсайдеров.\n",
"\n",
"---\n",
"\n",
"### Бизнес-цели:\n",
"1. **Для решения задачи регрессии:**\n",
2024-11-17 18:08:11 +04:00
"Предсказать будущую стоимость акций компании Tesla на основе инсайдерских транзакций. Стоимость акций (\"Cost\") зависит от множества факторов, включая объём и тип транзакций, совершаемых инсайдерами. Если выявить зависимости между параметрами транзакций (количество акций, общий объём сделки, должность инсайдера) и стоимостью акций, это может помочь инвесторам принимать обоснованные решения о покупке или продаже.\n",
2024-11-17 05:38:45 +04:00
"2. **Для решения задачи классификации:**\n",
2024-11-17 18:08:11 +04:00
"Классифицировать тип инсайдерской транзакции (продажа акций или исполнение опционов) на основе характеристик сделки. Тип транзакции (\"Transaction\") может быть индикатором доверия инсайдера к текущей рыночной цене или будущей прибыльности компании. Модель, которая предсказывает тип транзакции, может помочь в оценке поведения инсайдеров и выявлении аномалий.\n",
2024-11-17 05:38:45 +04:00
"\n",
"---"
2024-11-16 21:46:44 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Выгрузка данных из файла в DataFrame:"
]
},
{
"cell_type": "code",
2024-11-17 05:38:45 +04:00
"execution_count": 379,
2024-11-16 21:46:44 +04:00
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Tuple\n",
2024-11-17 05:38:45 +04:00
"from math import ceil\n",
2024-11-16 21:46:44 +04:00
"\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"df: DataFrame = pd.read_csv('..//static//csv//TSLA.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Краткая информация о DataFrame:"
]
},
{
"cell_type": "code",
2024-11-17 05:38:45 +04:00
"execution_count": 380,
2024-11-16 21:46:44 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 156 entries, 0 to 155\n",
"Data columns (total 9 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Insider Trading 156 non-null object \n",
" 1 Relationship 156 non-null object \n",
" 2 Date 156 non-null object \n",
" 3 Transaction 156 non-null object \n",
" 4 Cost 156 non-null float64\n",
" 5 Shares 156 non-null object \n",
" 6 Value ($) 156 non-null object \n",
" 7 Shares Total 156 non-null object \n",
" 8 SEC Form 4 156 non-null object \n",
"dtypes: float64(1), object(8)\n",
"memory usage: 11.1+ KB\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>count</th>\n",
" <th>mean</th>\n",
" <th>std</th>\n",
" <th>min</th>\n",
" <th>25%</th>\n",
" <th>50%</th>\n",
" <th>75%</th>\n",
" <th>max</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Cost</th>\n",
" <td>156.0</td>\n",
" <td>478.785641</td>\n",
" <td>448.922903</td>\n",
" <td>0.0</td>\n",
" <td>50.5225</td>\n",
" <td>240.225</td>\n",
" <td>934.1075</td>\n",
" <td>1171.04</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"Cost 156.0 478.785641 448.922903 0.0 50.5225 240.225 934.1075 1171.04"
]
},
2024-11-17 05:38:45 +04:00
"execution_count": 380,
2024-11-16 21:46:44 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Краткая информация о DataFrame\n",
"df.info()\n",
"\n",
"# Статистическое описание числовых столбцов\n",
"df.describe().transpose()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-11-16 21:59:40 +04:00
"### Предобработка данных:"
2024-11-16 21:46:44 +04:00
]
},
{
"cell_type": "code",
2024-11-17 05:38:45 +04:00
"execution_count": 381,
2024-11-16 21:46:44 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Выборка данных:\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Insider Trading</th>\n",
" <th>Relationship</th>\n",
" <th>Transaction</th>\n",
" <th>Cost</th>\n",
" <th>Shares</th>\n",
" <th>Value ($)</th>\n",
" <th>Shares Total</th>\n",
" <th>Year</th>\n",
" <th>Month</th>\n",
" <th>Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Kirkhorn Zachary</td>\n",
" <td>Chief Financial Officer</td>\n",
" <td>Sale</td>\n",
" <td>196.72</td>\n",
" <td>10455</td>\n",
" <td>2056775</td>\n",
" <td>203073</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Taneja Vaibhav</td>\n",
" <td>Chief Accounting Officer</td>\n",
" <td>Sale</td>\n",
" <td>195.79</td>\n",
" <td>2466</td>\n",
" <td>482718</td>\n",
" <td>100458</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Sale</td>\n",
" <td>195.79</td>\n",
" <td>1298</td>\n",
" <td>254232</td>\n",
" <td>65547</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Taneja Vaibhav</td>\n",
" <td>Chief Accounting Officer</td>\n",
" <td>Option Exercise</td>\n",
" <td>0.00</td>\n",
" <td>7138</td>\n",
" <td>0</td>\n",
" <td>102923</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Option Exercise</td>\n",
" <td>0.00</td>\n",
" <td>2586</td>\n",
" <td>0</td>\n",
" <td>66845</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Kirkhorn Zachary</td>\n",
" <td>Chief Financial Officer</td>\n",
" <td>Option Exercise</td>\n",
" <td>0.00</td>\n",
" <td>16867</td>\n",
" <td>0</td>\n",
" <td>213528</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Option Exercise</td>\n",
" <td>20.91</td>\n",
" <td>10500</td>\n",
" <td>219555</td>\n",
" <td>74759</td>\n",
" <td>2022</td>\n",
" <td>2</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Sale</td>\n",
" <td>202.00</td>\n",
" <td>10500</td>\n",
" <td>2121000</td>\n",
" <td>64259</td>\n",
" <td>2022</td>\n",
" <td>2</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>Kirkhorn Zachary</td>\n",
" <td>Chief Financial Officer</td>\n",
" <td>Sale</td>\n",
" <td>193.00</td>\n",
" <td>3750</td>\n",
" <td>723750</td>\n",
" <td>196661</td>\n",
" <td>2022</td>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Option Exercise</td>\n",
" <td>20.91</td>\n",
" <td>10500</td>\n",
" <td>219555</td>\n",
" <td>74759</td>\n",
" <td>2022</td>\n",
" <td>1</td>\n",
" <td>27</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Insider Trading Relationship Transaction Cost \\\n",
"0 Kirkhorn Zachary Chief Financial Officer Sale 196.72 \n",
"1 Taneja Vaibhav Chief Accounting Officer Sale 195.79 \n",
"2 Baglino Andrew D SVP Powertrain and Energy Eng. Sale 195.79 \n",
"3 Taneja Vaibhav Chief Accounting Officer Option Exercise 0.00 \n",
"4 Baglino Andrew D SVP Powertrain and Energy Eng. Option Exercise 0.00 \n",
"5 Kirkhorn Zachary Chief Financial Officer Option Exercise 0.00 \n",
"6 Baglino Andrew D SVP Powertrain and Energy Eng. Option Exercise 20.91 \n",
"7 Baglino Andrew D SVP Powertrain and Energy Eng. Sale 202.00 \n",
"8 Kirkhorn Zachary Chief Financial Officer Sale 193.00 \n",
"9 Baglino Andrew D SVP Powertrain and Energy Eng. Option Exercise 20.91 \n",
"\n",
" Shares Value ($) Shares Total Year Month Day \n",
"0 10455 2056775 203073 2022 3 6 \n",
"1 2466 482718 100458 2022 3 6 \n",
"2 1298 254232 65547 2022 3 6 \n",
"3 7138 0 102923 2022 3 5 \n",
"4 2586 0 66845 2022 3 5 \n",
"5 16867 0 213528 2022 3 5 \n",
"6 10500 219555 74759 2022 2 27 \n",
"7 10500 2121000 64259 2022 2 27 \n",
"8 3750 723750 196661 2022 2 6 \n",
"9 10500 219555 74759 2022 1 27 "
]
},
2024-11-17 05:38:45 +04:00
"execution_count": 381,
2024-11-16 21:46:44 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Преобразование типов данных\n",
"df['Insider Trading'] = df['Insider Trading'].astype('category') # Преобразование в категорию\n",
"df['Relationship'] = df['Relationship'].astype('category') # Преобразование в категорию\n",
"df['Transaction'] = df['Transaction'].astype('category') # Преобразование в категорию\n",
"df['Cost'] = pd.to_numeric(df['Cost'], errors='coerce') # Преобразование в float\n",
"df['Shares'] = pd.to_numeric(df['Shares'].str.replace(',', ''), errors='coerce') # Преобразование в float с удалением запятых\n",
"df['Value ($)'] = pd.to_numeric(df['Value ($)'].str.replace(',', ''), errors='coerce') # Преобразование в float с удалением запятых\n",
"df['Shares Total'] = pd.to_numeric(df['Shares Total'].str.replace(',', ''), errors='coerce') # Преобразование в float с удалением запятых\n",
"\n",
"df['Date'] = pd.to_datetime(df['Date'], errors='coerce') # Преобразование в datetime\n",
"df['Year'] = df['Date'].dt.year # Год\n",
"df['Month'] = df['Date'].dt.month # Месяц\n",
"df['Day'] = df['Date'].dt.day # День\n",
"df: DataFrame = df.drop(columns=['Date', 'SEC Form 4']) # Удаление столбцов с датами\n",
"\n",
"print('Выборка данных:')\n",
"df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Бизнес-цель №1 (Задача регрессии).\n",
"\n",
"### Достижимый уровень качества модели:\n",
"**Основные метрики для регрессии:**\n",
"- **Средняя абсолютная ошибка (Mean Absolute Error, MAE)** – показывает среднее абсолютное отклонение между предсказанными и фактическими значениями.\n",
"Легко интерпретируется, особенно в финансовых данных, где каждая ошибка в долларах имеет значение.\n",
"- **Среднеквадратичная ошибка (Mean Squared Error, MSE)** – показывает, насколько отклоняются прогнозы модели от истинных значений в квадрате. Подходит для оценки общего качества модели.\n",
"- **Коэффициент детерминации (R²)** – указывает, какую долю дисперсии зависимой переменной объясняет модель. R² варьируется от 0 до 1 (чем ближе к 1, тем лучше).\n",
"\n",
"---\n",
"\n",
"### Выбор ориентира:\n",
"В качестве базовой модели для оценки качества предсказаний выбрано использование среднего значения целевой переменной (Cost) на обучающей выборке. Это простой и интуитивно понятный метод, который служит минимальным ориентиром для сравнения с более сложными моделями. Базовая модель помогает установить начальный уровень ошибок (MAE, MSE) и показатель качества (R²), которые сложные модели должны улучшить, чтобы оправдать своё использование.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Разбиение данных:"
]
},
{
"cell_type": "code",
2024-11-17 05:38:45 +04:00
"execution_count": 382,
2024-11-16 21:46:44 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline MAE: 417.78235887096776\n",
"Baseline MSE: 182476.07973024843\n",
"Baseline R²: -0.027074997920953914\n"
]
}
],
"source": [
"from pandas.core.frame import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
"\n",
"\n",
"# Разбить данные на обучающую и тестовую выборки\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" stratify_colname: str = \"y\", \n",
" frac_train: float = 0.8,\n",
" random_state: int = 42,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" \n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" X: DataFrame = df_input # Contains all columns.\n",
" y: DataFrame = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
"\n",
" # Split original dataframe into train and test dataframes.\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" \n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"\n",
"# Определяем целевой признак и входные признаки\n",
"y_feature: str = 'Cost'\n",
"X_features: list[str] = df.drop(columns=y_feature, axis=1).columns.tolist()\n",
"\n",
"# Разбиваем данные на обучающую и тестовую выборки\n",
"X_df_train, X_df_test, y_df_train, y_df_test = split_into_train_test(\n",
" df, \n",
" stratify_colname=y_feature, \n",
" frac_train=0.8, \n",
" random_state=42 \n",
")\n",
"\n",
"# Вычисляем предсказания базовой модели (среднее значение целевой переменной)\n",
"baseline_predictions: list[float] = [y_df_train.mean()] * len(y_df_test) # type: ignore\n",
"\n",
"# Оцениваем базовую модель\n",
"print('Baseline MAE:', mean_absolute_error(y_df_test, baseline_predictions))\n",
"print('Baseline MSE:', mean_squared_error(y_df_test, baseline_predictions))\n",
"print('Baseline R²:', r2_score(y_df_test, baseline_predictions))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Выбор моделей обучения:\n",
"\n",
"Для обучения были выбраны следующие модели:\n",
"1. **Случайный лес (Random Forest)**: Ансамблевая модель, которая использует множество решающих деревьев. Она хорошо справляется с нелинейными зависимостями и шумом в данных, а также обладает устойчивостью к переобучению.\n",
"2. **Линейная регрессия (Linear Regression)**: Простая модель, предполагающая линейную зависимость между признаками и целевой переменной. Она быстро обучается и предоставляет легкую интерпретацию результатов.\n",
"3. **Градиентный бустинг (Gradient Boosting)**: Мощная модель, создающая ансамбль деревьев, которые корректируют ошибки предыдущих. Эта модель эффективна для сложных наборов данных и обеспечивает высокую точность предсказаний.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Построение конвейера:"
]
},
{
"cell_type": "code",
2024-11-17 05:38:45 +04:00
"execution_count": 383,
2024-11-16 21:46:44 +04:00
"metadata": {},
"outputs": [],
"source": [
"from sklearn.impute import SimpleImputer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"\n",
"# Числовые столбцы\n",
"num_columns: list[str] = [\n",
" column\n",
" for column in df.columns\n",
" if df[column].dtype not in (\"category\", \"object\")\n",
"]\n",
"\n",
"# Категориальные столбцы\n",
"cat_columns: list[str] = [\n",
" column\n",
" for column in df.columns\n",
" if df[column].dtype in (\"category\", \"object\")\n",
"]\n",
"\n",
"# Заполнение пропущенных значений\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"# Стандартизация\n",
"num_scaler = StandardScaler()\n",
"# Конвейер для обработки числовых данных\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Заполнение пропущенных значений\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"# Унитарное кодирование\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"# Конвейер для обработки категориальных данных\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Трансформер для предобработки признаков\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Основной конвейер предобработки данных\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Демонстрация работы конвейера:"
]
},
{
"cell_type": "code",
2024-11-17 05:38:45 +04:00
"execution_count": 384,
2024-11-16 21:46:44 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Cost</th>\n",
" <th>Shares</th>\n",
" <th>Value ($)</th>\n",
" <th>Shares Total</th>\n",
" <th>Year</th>\n",
" <th>Month</th>\n",
" <th>Day</th>\n",
" <th>Insider Trading_DENHOLM ROBYN M</th>\n",
" <th>Insider Trading_Kirkhorn Zachary</th>\n",
" <th>Insider Trading_Musk Elon</th>\n",
" <th>Insider Trading_Musk Kimbal</th>\n",
" <th>Insider Trading_Taneja Vaibhav</th>\n",
" <th>Insider Trading_Wilson-Thompson Kathleen</th>\n",
" <th>Relationship_Chief Accounting Officer</th>\n",
" <th>Relationship_Chief Financial Officer</th>\n",
" <th>Relationship_Director</th>\n",
" <th>Relationship_SVP Powertrain and Energy Eng.</th>\n",
" <th>Transaction_Sale</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.966516</td>\n",
" <td>-0.361759</td>\n",
" <td>-0.450022</td>\n",
" <td>-0.343599</td>\n",
" <td>0.715678</td>\n",
" <td>-0.506108</td>\n",
" <td>-0.400623</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>-1.074894</td>\n",
" <td>1.225216</td>\n",
" <td>-0.414725</td>\n",
" <td>-0.319938</td>\n",
" <td>-1.397276</td>\n",
" <td>0.801338</td>\n",
" <td>0.906673</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-1.074894</td>\n",
" <td>1.211753</td>\n",
" <td>-0.415027</td>\n",
" <td>-0.320141</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.098939</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.167142</td>\n",
" <td>0.037499</td>\n",
" <td>1.023612</td>\n",
" <td>-0.325853</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.501184</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.217886</td>\n",
" <td>-0.075287</td>\n",
" <td>0.632973</td>\n",
" <td>-0.330205</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.501184</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.505872</td>\n",
" <td>-0.361021</td>\n",
" <td>-0.443679</td>\n",
" <td>-0.343698</td>\n",
" <td>0.715678</td>\n",
" <td>-0.767598</td>\n",
" <td>1.308918</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>-1.088674</td>\n",
" <td>-0.357532</td>\n",
" <td>-0.450389</td>\n",
" <td>-0.342863</td>\n",
" <td>0.715678</td>\n",
" <td>0.278360</td>\n",
" <td>-0.903429</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>-0.692146</td>\n",
" <td>-0.355855</td>\n",
" <td>-0.445383</td>\n",
" <td>-0.343220</td>\n",
" <td>0.715678</td>\n",
" <td>0.801338</td>\n",
" <td>1.409480</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>-1.088674</td>\n",
" <td>-0.361181</td>\n",
" <td>-0.450389</td>\n",
" <td>-0.343649</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.903429</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1.091997</td>\n",
" <td>-0.204531</td>\n",
" <td>0.114712</td>\n",
" <td>1.538166</td>\n",
" <td>0.715678</td>\n",
" <td>-1.029087</td>\n",
" <td>1.208357</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Cost Shares Value ($) Shares Total Year Month Day \\\n",
"0 -0.966516 -0.361759 -0.450022 -0.343599 0.715678 -0.506108 -0.400623 \n",
"1 -1.074894 1.225216 -0.414725 -0.319938 -1.397276 0.801338 0.906673 \n",
"2 -1.074894 1.211753 -0.415027 -0.320141 -1.397276 1.062828 -0.098939 \n",
"3 1.167142 0.037499 1.023612 -0.325853 -1.397276 1.062828 -0.501184 \n",
"4 1.217886 -0.075287 0.632973 -0.330205 -1.397276 1.062828 -0.501184 \n",
"5 0.505872 -0.361021 -0.443679 -0.343698 0.715678 -0.767598 1.308918 \n",
"6 -1.088674 -0.357532 -0.450389 -0.342863 0.715678 0.278360 -0.903429 \n",
"7 -0.692146 -0.355855 -0.445383 -0.343220 0.715678 0.801338 1.409480 \n",
"8 -1.088674 -0.361181 -0.450389 -0.343649 -1.397276 1.062828 -0.903429 \n",
"9 1.091997 -0.204531 0.114712 1.538166 0.715678 -1.029087 1.208357 \n",
"\n",
" Insider Trading_DENHOLM ROBYN M Insider Trading_Kirkhorn Zachary \\\n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Insider Trading_Musk Elon Insider Trading_Musk Kimbal \\\n",
"0 0.0 0.0 \n",
"1 1.0 0.0 \n",
"2 1.0 0.0 \n",
"3 1.0 0.0 \n",
"4 1.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 1.0 0.0 \n",
"\n",
" Insider Trading_Taneja Vaibhav Insider Trading_Wilson-Thompson Kathleen \\\n",
"0 1.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 1.0 0.0 \n",
"7 0.0 0.0 \n",
"8 1.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Relationship_Chief Accounting Officer \\\n",
"0 1.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"5 0.0 \n",
"6 1.0 \n",
"7 0.0 \n",
"8 1.0 \n",
"9 0.0 \n",
"\n",
" Relationship_Chief Financial Officer Relationship_Director \\\n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Relationship_SVP Powertrain and Energy Eng. Transaction_Sale \n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 1.0 \n",
"4 0.0 1.0 \n",
"5 1.0 1.0 \n",
"6 0.0 0.0 \n",
"7 1.0 1.0 \n",
"8 0.0 0.0 \n",
"9 0.0 1.0 "
]
},
2024-11-17 05:38:45 +04:00
"execution_count": 384,
2024-11-16 21:46:44 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Применение конвейера\n",
"preprocessing_result = pipeline_end.fit_transform(X_df_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Обучение моделей:\n",
"\n",
"Оценка результатов обучения:\n",
"1. **Случайный лес (Random Forest)**:\n",
" - Показатели:\n",
" - Средний балл: 0.9993.\n",
" - Стандартное отклонение: 0.00046.\n",
" - Вывод: Очень высокая точность, что свидетельствует о хорошей способности модели к обобщению. Низкое значение стандартного отклонения указывает на стабильность модели.\n",
"2. **Линейная регрессия (Linear Regression)**:\n",
" - Показатели:\n",
" - Средний балл: 1.0.\n",
" - Стандартное отклонение: 0.0.\n",
" - Вывод: Идеальная точность, однако есть вероятность переобучения, так как стандартное отклонение равно 0. Это может указывать на то, что модель идеально подгоняет данные, но может не работать на новых данных.\n",
"3. **Градиентный бустинг (Gradient Boosting)**:\n",
" - Показатели:\n",
" - Средний балл: 0.9998.\n",
" - Стандартное отклонение: 0.00014.\n",
" - Вывод: Отличные результаты с высокой точностью и низкой вариабельностью. Модель также демонстрирует хорошую устойчивость."
]
},
{
"cell_type": "code",
2024-11-17 05:38:45 +04:00
"execution_count": 385,
2024-11-16 21:46:44 +04:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-17 05:38:45 +04:00
"Модель: Random Forest\n",
"\tmean_score: 0.9992580181099008\n",
"\tstd_dev: 0.0004834744839371662\n",
"\n",
"Модель: Linear Regression\n",
"\tmean_score: 1.0\n",
"\tstd_dev: 0.0\n",
"\n",
"Модель: Gradient Boosting\n",
"\tmean_score: 0.9997687065029746\n",
"\tstd_dev: 0.00014193622424523165\n",
"\n"
2024-11-16 21:46:44 +04:00
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"\n",
"# Обучить модели\n",
"def train_models(X: DataFrame, y: DataFrame, \n",
2024-11-17 05:38:45 +04:00
" models: dict[str, Any]) -> dict[str, dict[str, Any]]:\n",
" results: dict[str, dict[str, Any]] = {}\n",
" \n",
2024-11-16 21:46:44 +04:00
" for model_name, model in models.items():\n",
2024-11-17 05:38:45 +04:00
" # Создание конвейера для текущей модели\n",
2024-11-16 21:46:44 +04:00
" model_pipeline = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
2024-11-17 05:38:45 +04:00
" (\"model\", model)\n",
2024-11-16 21:46:44 +04:00
" ]\n",
" )\n",
" \n",
" # Обучаем модель и вычисляем кросс-валидацию\n",
" scores = cross_val_score(model_pipeline, X, y, cv=5) # 5-кратная кросс-валидация\n",
2024-11-17 05:38:45 +04:00
" \n",
" # Вычисление метрик для текущей модели\n",
" metrics_dict: dict[str, Any] = {\n",
2024-11-16 21:46:44 +04:00
" \"mean_score\": scores.mean(),\n",
" \"std_dev\": scores.std()\n",
" }\n",
2024-11-17 05:38:45 +04:00
" \n",
" # Сохранениерезультатов\n",
" results[model_name] = metrics_dict\n",
2024-11-16 21:46:44 +04:00
" \n",
" return results\n",
"\n",
"\n",
2024-11-17 05:38:45 +04:00
"# Выбранные модели для регрессии\n",
2024-11-16 21:46:44 +04:00
"models_regression: dict[str, Any] = {\n",
" \"Random Forest\": RandomForestRegressor(),\n",
" \"Linear Regression\": LinearRegression(),\n",
" \"Gradient Boosting\": GradientBoostingRegressor(),\n",
"}\n",
"\n",
"results: dict[str, Any] = train_models(X_df_train, y_df_train, models_regression)\n",
"\n",
"# Вывод результатов\n",
2024-11-17 05:38:45 +04:00
"for model_name, metrics_dict in results.items():\n",
" print(f\"Модель: {model_name}\")\n",
" for metric_name, value in metrics_dict.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
" print()"
2024-11-16 21:46:44 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Проверка на тестовом наборе данных:\n",
"\n",
"Оценка результатов обучения:\n",
"1. Случайный лес (Random Forest):\n",
" - Показатели:\n",
" - MAE (обучение): 1.858\n",
" - MAE (тест): 4.489\n",
" - MSE (обучение): 10.959\n",
" - MSE (тест): 62.649\n",
2024-11-17 05:38:45 +04:00
" - R² (обучение): 0.9999\n",
" - R² (тест): 0.9997\n",
2024-11-16 21:46:44 +04:00
" - STD (обучение): 3.310\n",
" - STD (тест): 7.757\n",
" - Вывод: Случайный лес показывает великолепные значения R2 на обучающей и тестовой выборках, что свидетельствует о сильной способности к обобщению. Однако MAE и MSE на тестовой выборке значительно выше, чем на обучающей, что может указывать на некоторые проблемы с переобучением.\n",
"2. Линейная регрессия (Linear Regression):\n",
" - Показатели:\n",
" - MAE (обучение): 3.069e-13\n",
" - MAE (тест): 2.762e-13\n",
" - MSE (обучение): 1.437e-25\n",
" - MSE (тест): 1.196e-25\n",
2024-11-17 05:38:45 +04:00
" - R² (обучение): 1.0\n",
" - R² (тест): 1.0\n",
2024-11-16 21:46:44 +04:00
" - STD (обучение): 3.730e-13\n",
" - STD (тест): 3.444e-13\n",
" - Вывод: Высокие показатели точности и нулевые ошибки (MAE, MSE) указывают на то, что модель идеально подгоняет данные как на обучающей, так и на тестовой выборках. Однако это также может быть признаком переобучения.\n",
"3. Градиентный бустинг (Gradient Boosting):\n",
" - Показатели:\n",
" - MAE (обучение): 0.156\n",
" - MAE (тест): 3.027\n",
" - MSE (обучение): 0.075\n",
" - MSE (тест): 41.360\n",
2024-11-17 05:38:45 +04:00
" - R² (обучение): 0.9999996\n",
" - R² (тест): 0.9998\n",
2024-11-16 21:46:44 +04:00
" - STD (обучение): 0.274\n",
" - STD (тест): 6.399\n",
" - Вывод: Градиентный бустинг демонстрирует отличные результаты на обучающей выборке, однако MAE и MSE на тестовой выборке довольно высокие, что может указывать на определенное переобучение или необходимость улучшения настройки модели."
]
},
{
"cell_type": "code",
2024-11-17 05:38:45 +04:00
"execution_count": 386,
2024-11-16 21:46:44 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Модель: Random Forest\n",
2024-11-17 05:38:45 +04:00
"\tMAE_train: 1.955516935483828\n",
"\tMAE_test: 4.46537187499996\n",
"\tMSE_train: 11.287871282983637\n",
"\tMSE_test: 66.47081479843644\n",
"\tR2_train: 0.9999449583585838\n",
"\tR2_test: 0.9996258659651619\n",
"\tSTD_train: 3.351830348079478\n",
"\tSTD_test: 8.067958792345765\n",
2024-11-16 21:46:44 +04:00
"\n",
"Модель: Linear Regression\n",
"\tMAE_train: 3.0690862038154006e-13\n",
"\tMAE_test: 2.761679773755077e-13\n",
"\tMSE_train: 1.4370485712253764e-25\n",
"\tMSE_test: 1.19585889812782e-25\n",
"\tR2_train: 1.0\n",
"\tR2_test: 1.0\n",
"\tSTD_train: 3.7295840825107354e-13\n",
"\tSTD_test: 3.4438670391637766e-13\n",
"\n",
"Модель: Gradient Boosting\n",
2024-11-17 05:38:45 +04:00
"\tMAE_train: 0.15613772760448247\n",
"\tMAE_test: 2.9760510050502877\n",
"\tMSE_train: 0.07499640211231862\n",
"\tMSE_test: 38.91708171007616\n",
2024-11-16 21:46:44 +04:00
"\tR2_train: 0.9999996343043813\n",
2024-11-17 05:38:45 +04:00
"\tR2_test: 0.9997809534176997\n",
"\tSTD_train: 0.2738547098596601\n",
"\tSTD_test: 6.197132274535746\n",
2024-11-16 21:46:44 +04:00
"\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"from sklearn import metrics\n",
"\n",
"\n",
"# Оценка качества различных моделей на основе метрик\n",
2024-11-16 21:59:40 +04:00
"def evaluate_models(models: dict[str, Any], \n",
" pipeline_end: Pipeline, \n",
" X_train: DataFrame, y_train, \n",
" X_test: DataFrame, y_test) -> dict[str, dict[str, Any]]:\n",
2024-11-16 21:46:44 +04:00
" results: dict[str, dict[str, Any]] = {}\n",
" \n",
" for model_name, model in models.items():\n",
2024-11-17 05:38:45 +04:00
" # Создание конвейера для текущей модели\n",
2024-11-16 21:46:44 +04:00
" model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", model),\n",
" ]\n",
" )\n",
" \n",
" # Обучение текущей модели\n",
" model_pipeline.fit(X_train, y_train)\n",
"\n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
"\n",
" # Вычисление метрик для текущей модели\n",
" metrics_dict: dict[str, Any] = {\n",
" \"MAE_train\": metrics.mean_absolute_error(y_train, y_train_predict),\n",
" \"MAE_test\": metrics.mean_absolute_error(y_test, y_test_predict),\n",
" \"MSE_train\": metrics.mean_squared_error(y_train, y_train_predict),\n",
" \"MSE_test\": metrics.mean_squared_error(y_test, y_test_predict),\n",
" \"R2_train\": metrics.r2_score(y_train, y_train_predict),\n",
" \"R2_test\": metrics.r2_score(y_test, y_test_predict),\n",
" \"STD_train\": np.std(y_train - y_train_predict),\n",
" \"STD_test\": np.std(y_test - y_test_predict),\n",
" }\n",
"\n",
" # Сохранение результатов\n",
" results[model_name] = metrics_dict\n",
" \n",
" return results\n",
"\n",
"\n",
"y_train = np.ravel(y_df_train) \n",
"y_test = np.ravel(y_df_test) \n",
"\n",
2024-11-17 05:38:45 +04:00
"results: dict[str, dict[str, Any]] = evaluate_models(models_regression,\n",
" pipeline_end,\n",
" X_df_train, y_train,\n",
" X_df_test, y_test)\n",
2024-11-16 21:46:44 +04:00
"\n",
"# Вывод результатов\n",
2024-11-17 05:38:45 +04:00
"for model_name, metrics_dict in results.items():\n",
2024-11-16 21:46:44 +04:00
" print(f\"Модель: {model_name}\")\n",
" for metric_name, value in metrics_dict.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Подбор гиперпараметров:"
]
},
{
"cell_type": "code",
2024-11-17 05:38:45 +04:00
"execution_count": 387,
2024-11-16 21:46:44 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
2024-11-17 05:38:45 +04:00
"Лучшие параметры: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 50}\n",
"Лучший результат (MSE): 196.9489804872991\n"
2024-11-16 21:46:44 +04:00
]
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"\n",
"# Применение конвейера к данным\n",
"X_train_processing_result = pipeline_end.fit_transform(X_df_train)\n",
"X_test_processing_result = pipeline_end.transform(X_df_test)\n",
"\n",
"# Создание и настройка модели случайного леса\n",
"model = RandomForestRegressor()\n",
"\n",
"# Установка параметров для поиска по сетке\n",
"param_grid: dict[str, list[int | None]] = {\n",
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке\n",
"grid_search = GridSearchCV(estimator=model, \n",
" param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Результаты подбора гиперпараметров\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"# Меняем знак, так как берем отрицательное значение среднеквадратичной ошибки\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Сравнение наборов гиперпараметров:\n",
"\n",
"Результаты анализа показывают, что параметры из старой сетки обеспечивают значительно лучшее качество модели. Среднеквадратическая ошибка (MSE) на кросс-валидации для старых параметров составила 179.369, что существенно ниже, чем для новых параметров (1290.656). Н а тестовой выборке модель с новыми параметрами показала MSE 172.574, что сопоставимо с результатами модели с о старыми параметрами, однако этот результат является случайным, так как новые параметры продемонстрировали плохую кросс-валидационную ошибку, указывая на недообучение. Таким образом, параметры из старой сетки более предпочтительны, так как они обеспечивают лучшее обобщение и меньшую ошибку."
]
},
{
"cell_type": "code",
2024-11-17 05:38:45 +04:00
"execution_count": 388,
2024-11-16 21:46:44 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
2024-11-17 05:38:45 +04:00
"Старые параметры: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 50}\n",
"Лучший результат (MSE) на старых параметрах: 184.14248778487732\n",
2024-11-16 21:46:44 +04:00
"\n",
"Новые параметры: {'max_depth': 5, 'min_samples_split': 10, 'n_estimators': 50}\n",
2024-11-17 05:38:45 +04:00
"Лучший результат (MSE) на новых параметрах: 1283.4356458868208\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 159.03284823315155\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 12.610822662822262\n"
2024-11-16 21:46:44 +04:00
]
},
{
"data": {
2024-11-17 05:38:45 +04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1sAAAHWCAYAAACBjZMqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hTZfvA8W+SNk2b7j1pgbKHyBDZoMhwIjhBBQf4U1TciggiqCjgAtfrqy84cKLiHigiIAjIlj0KpXsl3SPj/P44NFI6aEvbpO39uS4uTc5zzrmbJum5z/M896NRFEVBCCGEEEIIIUSD0jo7ACGEEEIIIYRoiSTZEkIIIYQQQohGIMmWEEIIIYQQQjQCSbaEEEIIIYQQohFIsiWEEEIIIYQQjUCSLSGEEEIIIYRoBJJsCSGEEEIIIUQjkGRLCCGEEEIIIRqBJFtCCCGEEEII0Qgk2RJCCCGEEI3uu+++Y+fOnY7Hq1atYu/evc4LSIgmIMmWEK3A0aNHufPOO2nXrh0GgwFfX18GDRrEq6++SnFxsbPDE0II0Qrs2bOHGTNmcPjwYf766y/+7//+j/z8fGeHJUSj0iiKojg7CCFE4/n++++59tpr8fDw4JZbbqF79+6UlZWxYcMGvvjiC6ZMmcLbb7/t7DCFEEK0cJmZmQwcOJAjR44AMH78eL744gsnRyVE45JkS4gWLCEhgZ49exIdHc2aNWuIiIiosP3IkSN8//33zJgxw0kRCiGEaE1KS0v5559/8PLyokuXLs4OR4hGJ8MIhWjBFi5cSEFBAe+++26lRAsgPj6+QqKl0Wi45557WLFiBZ06dcJgMNCnTx/WrVtXYb8TJ05w991306lTJzw9PQkKCuLaa6/l+PHjFdotX74cjUbj+Ofl5UWPHj145513KrSbMmUK3t7eleJbuXIlGo2GtWvXVnh+8+bNjBkzBj8/P7y8vBg2bBh//vlnhTZz585Fo9GQlZVV4fm///4bjUbD8uXLK5w/Li6uQruTJ0/i6emJRqOp9HP9+OOPDBkyBKPRiI+PD5dddlmt5h2Uvx7r1q3jzjvvJCgoCF9fX2655RZMJlOl9rU5z+7du5kyZYpjiGh4eDi33XYb2dnZVcYQFxdX4XdS/u/01zguLo7LL7+8xp/l+PHjaDQaFi9eXGlb9+7dGT58uOPx2rVr0Wg0rFy5strjnfk7eOqpp9Bqtfz2228V2k2bNg29Xs+uXbtqjE+j0TB37twKzy1atAiNRlMhtpr2r+7f6XGe/jq8/PLLxMbG4unpybBhw/jnn38qHffAgQNcc801BAYGYjAY6Nu3L998802VMUyZMqXK80+ZMqVS2x9//JFhw4bh4+ODr68v/fr146OPPnJsHz58eKWf+9lnn0Wr1VZot379eq699lratGmDh4cHMTExPPDAA5WGG8+dO5euXbvi7e2Nr68vF154IatWrarQprbHqsvnf/jw4XTv3r1S28WLF1f6rJ7tfVz+viw//v79+/H09OSWW26p0G7Dhg3odDoee+yxao8FtXtN6hL/119/zWWXXUZkZCQeHh60b9+e+fPnY7PZKuxb1Xu9/LumPt9ddf19nPm+2rp1q+O9WlWcHh4e9OnThy5dutTpMylEc+Xm7ACEEI3n22+/pV27dgwcOLDW+/zxxx98+umn3HfffXh4ePDGG28wZswYtmzZ4rhI2Lp1Kxs3buSGG24gOjqa48eP8+abbzJ8+HD27duHl5dXhWO+/PLLBAcHk5eXx//+9z+mTp1KXFwcI0eOrPPPtGbNGsaOHUufPn0cF+TLli3joosuYv369VxwwQV1PmZV5syZQ0lJSaXnP/jgAyZPnszo0aN54YUXKCoq4s0332Tw4MHs2LGjUtJWlXvuuQd/f3/mzp3LwYMHefPNNzlx4oTj4q8u51m9ejXHjh3j1ltvJTw8nL179/L222+zd+9e/vrrr0oXPABDhgxh2rRpgHqB+dxzz9X/hWokTz75JN9++y233347e/bswcfHh59//pn//ve/zJ8/n/POO69OxzObzSxYsKBO+1xyySWVLrxffPHFKhPj999/n/z8fKZPn05JSQmvvvoqF110EXv27CEsLAyAvXv3MmjQIKKionj88ccxGo189tlnjBs3ji+++IKrr7660nE9PDwq3Jy44447KrVZvnw5t912G926dWPmzJn4+/uzY8cOfvrpJyZOnFjlz7Zs2TKefPJJXnzxxQptPv/8c4qKirjrrrsICgpiy5YtLF26lKSkJD7//HNHu8LCQq6++mri4uIoLi5m+fLlTJgwgU2bNjk+g7U9lqvo0qUL8+fP55FHHuGaa67hyiuvpLCwkClTptC5c2fmzZtX4/61eU3qYvny5Xh7e/Pggw/i7e3NmjVrmDNnDnl5eSxatKjOx2uI767aOFtSWq4+n0khmiVFCNEi5ebmKoBy1VVX1XofQAGUv//+2/HciRMnFIPBoFx99dWO54qKiirtu2nTJgVQ3n//fcdzy5YtUwAlISHB8dyhQ4cUQFm4cKHjucmTJytGo7HSMT///HMFUH7//XdFURTFbrcrHTp0UEaPHq3Y7fYK8bRt21a55JJLHM899dRTCqBkZmZWOObWrVsVQFm2bFmF88fGxjoe//PPP4pWq1XGjh1bIf78/HzF399fmTp1aoVjpqWlKX5+fpWeP1P569GnTx+lrKzM8fzChQsVQPn666/rfJ6qfhcff/yxAijr1q2rtC0qKkq59dZbHY9///33Cq+xoihKbGysctlll9X4syQkJCiAsmjRokrbunXrpgwbNqzSOT7//PNqj3fm70BRFGXPnj2KXq9X7rjjDsVkMilRUVFK3759FYvFUmNsiqK+l5966inH40cffVQJDQ1V+vTpUyG2mvafPn16pecvu+yyCnGWvw6enp5KUlKS4/nNmzcrgPLAAw84nrv44ouVHj16KCUlJY7n7Ha7MnDgQKVDhw6VzjVx4kTF29u7wnNGo1GZPHmy47HZbFZ8fHyU/v37K8XFxRXanv4ZGTZsmOPn/v777xU3NzfloYceqnTOqt5PCxYsUDQajXLixIlK28plZGQogLJ48eI6H6u2n//yn6Nbt26V2i5atKjSd83Z3sdVvfdtNpsyePBgJSwsTMnKylKmT5+uuLm5KVu3bq32ONWp6jWpS/xVvX533nmn4uXlVeE9pNFolDlz5lRod+Z3b12+U+r6+zj98/TDDz8ogDJmzBjlzEvMc/1MCtFcyTBCIVqovLw8AHx8fOq034ABA+jTp4/jcZs2bbjqqqv4+eefHcNXPD09HdstFgvZ2dnEx8fj7+/P9u3bKx3TZDKRlZXFsWPHePnll9HpdAwbNqxSu6ysrAr/zqxStXPnTg4fPszEiRPJzs52tCssLOTiiy9m3bp12O32Cvvk5ORUOGZubu5ZX4OZM2fSu3dvrr322grPr169GrPZzI033ljhmDqdjv79+/P777+f9digDoVzd3d3PL7rrrtwc3Pjhx9+qPN5Tv9dlJSUkJWVxYUXXghQ5e+irKwMDw+Ps8ZosVjIysoiOzsbq9VabbuioqJKv7czhzmVy8/PJysrC7PZfNbzgzoc8emnn+add95h9OjRZGVl8d577+HmVrdBGcnJySxdupTZs2dXOTyqIYwbN46oqCjH4wsuuID+/fs7fqc5OTmsWbOG6667zvE6lL++o0eP5vDhwyQnJ1c4ZklJCQaDocbzrl69mvz8fB5//PFKbavq1dyyZQvXXXcdEyZMqLJ35PT3U2FhIVlZWQwcOBBFUdixY0eFtuXvkaNHj/L888+j1WoZNGhQvY4FZ//8l7PZbJXaFhUVVdm2tu/jclqtluXLl1NQUMDYsWN54403mDlzJn379j3rvqefr7rXpC7xn/76lb9nhgwZQlFREQcOHHBsCw0NJSkpqca46vPdVdvfRzlFUZg5cyYTJkygf//+NbZtis+kEK5ChhEK0UL5+voC1LmsbocOHSo917FjR4qKisjMzCQ8PJzi4mIWLFjAsmXLSE5ORjmtzk5VyUzv3r0d/+/h4cFrr71WaVhNYWE
2024-11-16 21:46:44 +04:00
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Установка параметров для поиска по сетке для старых значений\n",
"old_param_grid: dict[str, list[int | None]] = {\n",
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке для старых параметров\n",
"old_grid_search = GridSearchCV(estimator=model, \n",
" param_grid=old_param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"old_grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Результаты подбора для старых параметров\n",
"old_best_params = old_grid_search.best_params_\n",
2024-11-16 21:59:40 +04:00
"# Меняем знак, так как берем отрицательное значение MSE\n",
2024-11-16 21:46:44 +04:00
"old_best_mse = -old_grid_search.best_score_\n",
"\n",
"\n",
"# Установка параметров для поиска по сетке для новых значений\n",
"new_param_grid: dict[str, list[int]] = {\n",
" 'n_estimators': [50],\n",
" 'max_depth': [5],\n",
" 'min_samples_split': [10]\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке для новых параметров\n",
"new_grid_search = GridSearchCV(estimator=model, \n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"new_grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Результаты подбора для новых параметров\n",
"new_best_params = new_grid_search.best_params_\n",
"# Меняем знак, так как берем отрицательное значение MSE\n",
"new_best_mse = -new_grid_search.best_score_\n",
"\n",
2024-11-16 21:59:40 +04:00
"\n",
2024-11-16 21:46:44 +04:00
"# Обучение модели с лучшими параметрами для новых значений\n",
"model_best = RandomForestRegressor(**new_best_params)\n",
"model_best.fit(X_train_processing_result, y_train)\n",
"\n",
"# Прогнозирование на тестовой выборке\n",
"y_pred = model_best.predict(X_test_processing_result)\n",
"\n",
"# Оценка производительности модели\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
2024-11-16 21:59:40 +04:00
"\n",
2024-11-16 21:46:44 +04:00
"# Вывод результатов\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nН о вые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)\n",
"\n",
"# Обучение модели с лучшими параметрами для старых значений\n",
"model_old = RandomForestRegressor(**old_best_params)\n",
"model_old.fit(X_train_processing_result, y_train)\n",
"\n",
"# Прогнозирование на тестовой выборке для старых параметров\n",
"y_pred_old = model_old.predict(X_test_processing_result)\n",
"\n",
"# Визуализация ошибок\n",
"plt.figure(figsize=(10, 5))\n",
"plt.plot(y_test, label='Реальные значения', marker='o', linestyle='-', color='black')\n",
"plt.plot(y_pred_old, label='Предсказанные значения (старые параметры)', marker='x', linestyle='--', color='blue')\n",
"plt.plot(y_pred, label='Предсказанные значения (новые параметры)', marker='s', linestyle='--', color='orange')\n",
"plt.xlabel('Объекты')\n",
"plt.ylabel('Значения')\n",
"plt.title('Сравнение реальных и предсказанных значений')\n",
"plt.legend()\n",
"plt.show()"
]
2024-11-17 05:38:45 +04:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Бизес-цель №2 (Задача классификации).\n",
"\n",
"### Достижимый уровень качества модели:\n",
"**Основные метрики для классификации:**\n",
"- **Accuracy (точность)** – показывает долю правильно классифицированных примеров среди всех наблюдений. Легко интерпретируется, но может быть недостаточно информативной для несбалансированных классов.\n",
"- **F1-Score** – гармоническое среднее между точностью (precision) и полнотой (recall). Подходит для задач, где важно одновременно учитывать как ложные положительные, так и ложные отрицательные ошибки, особенно при несбалансированных классах.\n",
"- **ROC AUC (Area Under the ROC Curve)** – отражает способность модели различать положительные и отрицательные классы на всех уровнях порога вероятности. Значение от 0.5 (случайное угадывание) до 1.0 (идеальная модель). Полезна для оценки модели на несбалансированных данных.\n",
"- **Cohen's Kappa** – измеряет степень согласия между предсказаниями модели и истинными метками с учётом случайного угадывания. Значения варьируются от -1 (полное несогласие) до 1 (идеальное согласие). Удобна для оценки на несбалансированных данных.\n",
"- **MCC (Matthews Correlation Coefficient)** – метрика корреляции между предсказаниями и истинными классами, учитывающая все типы ошибок (TP, TN, FP, FN). Значение варьируется от -1 (полная несоответствие) до 1 (идеальное совпадение). Отлично подходит для задач с несбалансированными классами.\n",
"- **Confusion Matrix (матрица ошибок)** – матрица ошибок отражает распределение предсказаний модели по каждому из классов.\n",
"\n",
"---\n",
"\n",
"### Выбор ориентира:\n",
"В качестве базовой модели для оценки качества предсказаний выбрано использование самой распространённой категории целевой переменной (\"Transaction\") в обучающей выборке. Этот подход, известный как \"most frequent class baseline\", заключается в том, что модель всегда предсказывает наиболее часто встречающийся тип транзакции.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Разбиение данных:"
]
},
{
"cell_type": "code",
"execution_count": 389,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Самый частый класс: Sale\n",
"Baseline Accuracy: 0.59375\n",
"Baseline F1: 0.4424019607843137\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score, f1_score\n",
"\n",
"\n",
"# Определяем целевой признак и входные признаки\n",
"y_feature: str = 'Transaction'\n",
"X_features: list[str] = df.drop(columns=y_feature, axis=1).columns.tolist()\n",
"\n",
"# Разбиваем данные на обучающую и тестовую выборки\n",
"X_df_train, X_df_test, y_df_train, y_df_test = split_into_train_test(\n",
" df, \n",
" stratify_colname=y_feature, \n",
" frac_train=0.8, \n",
" random_state=42 \n",
")\n",
"\n",
"# Определяем самый частый класс\n",
"most_frequent_class = y_df_train.mode().values[0][0]\n",
"print(f\"Самый частый класс: {most_frequent_class}\")\n",
"\n",
"# Вычисляем предсказания базовой модели (все предсказания равны самому частому классу)\n",
"baseline_predictions: list[str] = [most_frequent_class] * len(y_df_test)\n",
"\n",
"# Оцениваем базовую модель\n",
"print('Baseline Accuracy:', accuracy_score(y_df_test, baseline_predictions))\n",
"print('Baseline F1:', f1_score(y_df_test, baseline_predictions, average='weighted'))\n",
"\n",
"# Унитарное кодирование для целевого признака\n",
"y_df_train = y_df_train['Transaction'].map({'Sale': 1, 'Option Exercise': 0})\n",
"y_df_test = y_df_test['Transaction'].map({'Sale': 1, 'Option Exercise': 0})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Выбор моделей обучения:\n",
"\n",
"Для обучения были выбраны следующие модели:\n",
"1. **Случайный лес (Random Forest)**: Ансамблевая модель, которая использует множество решающих деревьев. Она хорошо справляется с нелинейными зависимостями и шумом в данных, а также обладает устойчивостью к переобучению.\n",
"2. **Логистическая регрессия (Logistic Regression)**: Статистический метод для бинарной классификации, который моделирует зависимость между целевой переменной и независимыми признаками, используя логистическую функцию. Она проста в интерпретации и быстра в обучении.\n",
"3. **Метод ближайших соседей (KNN)**: Алгоритм классификации, который предсказывает класс на основе ближайших k обучающих примеров. KNN интуитивно понятен и не требует обучения, но может быть медленным на больших данных и чувствительным к выбору параметров.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Построение конвейера:\n",
"\n",
"Конвейеры для обработки числовых и категориальных значений, а так же основной конвейер уже были построены ранее при решении задачи регрессии."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Демонстрация работы конвейера:"
]
},
{
"cell_type": "code",
"execution_count": 390,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Cost</th>\n",
" <th>Shares</th>\n",
" <th>Value ($)</th>\n",
" <th>Shares Total</th>\n",
" <th>Year</th>\n",
" <th>Month</th>\n",
" <th>Day</th>\n",
" <th>Insider Trading_DENHOLM ROBYN M</th>\n",
" <th>Insider Trading_Kirkhorn Zachary</th>\n",
" <th>Insider Trading_Musk Elon</th>\n",
" <th>Insider Trading_Musk Kimbal</th>\n",
" <th>Insider Trading_Taneja Vaibhav</th>\n",
" <th>Insider Trading_Wilson-Thompson Kathleen</th>\n",
" <th>Relationship_Chief Accounting Officer</th>\n",
" <th>Relationship_Chief Financial Officer</th>\n",
" <th>Relationship_Director</th>\n",
" <th>Relationship_SVP Powertrain and Energy Eng.</th>\n",
" <th>Transaction_Sale</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.966516</td>\n",
" <td>-0.361759</td>\n",
" <td>-0.450022</td>\n",
" <td>-0.343599</td>\n",
" <td>0.715678</td>\n",
" <td>-0.506108</td>\n",
" <td>-0.400623</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>-1.074894</td>\n",
" <td>1.225216</td>\n",
" <td>-0.414725</td>\n",
" <td>-0.319938</td>\n",
" <td>-1.397276</td>\n",
" <td>0.801338</td>\n",
" <td>0.906673</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-1.074894</td>\n",
" <td>1.211753</td>\n",
" <td>-0.415027</td>\n",
" <td>-0.320141</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.098939</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.167142</td>\n",
" <td>0.037499</td>\n",
" <td>1.023612</td>\n",
" <td>-0.325853</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.501184</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.217886</td>\n",
" <td>-0.075287</td>\n",
" <td>0.632973</td>\n",
" <td>-0.330205</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.501184</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.505872</td>\n",
" <td>-0.361021</td>\n",
" <td>-0.443679</td>\n",
" <td>-0.343698</td>\n",
" <td>0.715678</td>\n",
" <td>-0.767598</td>\n",
" <td>1.308918</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>-1.088674</td>\n",
" <td>-0.357532</td>\n",
" <td>-0.450389</td>\n",
" <td>-0.342863</td>\n",
" <td>0.715678</td>\n",
" <td>0.278360</td>\n",
" <td>-0.903429</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>-0.692146</td>\n",
" <td>-0.355855</td>\n",
" <td>-0.445383</td>\n",
" <td>-0.343220</td>\n",
" <td>0.715678</td>\n",
" <td>0.801338</td>\n",
" <td>1.409480</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>-1.088674</td>\n",
" <td>-0.361181</td>\n",
" <td>-0.450389</td>\n",
" <td>-0.343649</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.903429</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1.091997</td>\n",
" <td>-0.204531</td>\n",
" <td>0.114712</td>\n",
" <td>1.538166</td>\n",
" <td>0.715678</td>\n",
" <td>-1.029087</td>\n",
" <td>1.208357</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Cost Shares Value ($) Shares Total Year Month Day \\\n",
"0 -0.966516 -0.361759 -0.450022 -0.343599 0.715678 -0.506108 -0.400623 \n",
"1 -1.074894 1.225216 -0.414725 -0.319938 -1.397276 0.801338 0.906673 \n",
"2 -1.074894 1.211753 -0.415027 -0.320141 -1.397276 1.062828 -0.098939 \n",
"3 1.167142 0.037499 1.023612 -0.325853 -1.397276 1.062828 -0.501184 \n",
"4 1.217886 -0.075287 0.632973 -0.330205 -1.397276 1.062828 -0.501184 \n",
"5 0.505872 -0.361021 -0.443679 -0.343698 0.715678 -0.767598 1.308918 \n",
"6 -1.088674 -0.357532 -0.450389 -0.342863 0.715678 0.278360 -0.903429 \n",
"7 -0.692146 -0.355855 -0.445383 -0.343220 0.715678 0.801338 1.409480 \n",
"8 -1.088674 -0.361181 -0.450389 -0.343649 -1.397276 1.062828 -0.903429 \n",
"9 1.091997 -0.204531 0.114712 1.538166 0.715678 -1.029087 1.208357 \n",
"\n",
" Insider Trading_DENHOLM ROBYN M Insider Trading_Kirkhorn Zachary \\\n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Insider Trading_Musk Elon Insider Trading_Musk Kimbal \\\n",
"0 0.0 0.0 \n",
"1 1.0 0.0 \n",
"2 1.0 0.0 \n",
"3 1.0 0.0 \n",
"4 1.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 1.0 0.0 \n",
"\n",
" Insider Trading_Taneja Vaibhav Insider Trading_Wilson-Thompson Kathleen \\\n",
"0 1.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 1.0 0.0 \n",
"7 0.0 0.0 \n",
"8 1.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Relationship_Chief Accounting Officer \\\n",
"0 1.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"5 0.0 \n",
"6 1.0 \n",
"7 0.0 \n",
"8 1.0 \n",
"9 0.0 \n",
"\n",
" Relationship_Chief Financial Officer Relationship_Director \\\n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Relationship_SVP Powertrain and Energy Eng. Transaction_Sale \n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 1.0 \n",
"4 0.0 1.0 \n",
"5 1.0 1.0 \n",
"6 0.0 0.0 \n",
"7 1.0 1.0 \n",
"8 0.0 0.0 \n",
"9 0.0 1.0 "
]
},
"execution_count": 390,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Применение конвейера\n",
"preprocessing_result = pipeline_end.fit_transform(X_df_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Проверка моделей:\n",
"\n",
"Оценка результатов обучения:\n",
"1. **Случайный лес (Random Forest)**:\n",
" - Показатели:\n",
" - Precision (обучение): 1.0\n",
" - Precision (тест): 1.0\n",
" - Recall (обучение): 1.0\n",
" - Recall (тест): 1.0\n",
" - Accuracy (обучение): 1.0\n",
" - Accuracy (тест): 1.0\n",
" - F1 Score (обучение): 1.0\n",
" - F1 Score (тест): 1.0\n",
" - ROC AUC (тест): 1.0\n",
" - Cohen Kappa (тест): 1.0\n",
" - MCC (тест): 1.0\n",
" - Confusion Matrix (тест):\n",
" ```\n",
" [[13, 0],\n",
" [ 0, 19]]\n",
" ```\n",
" - Вывод: Случайный лес идеально справляется с задачей на обеих выборках. Однако столь высокие значения метрик на обучении и тесте могут указывать на переобучение модели.\n",
"2. **Логистическая регрессия (Logistic Regression)**:\n",
" - Показатели:\n",
" - Precision (обучение): 1.0\n",
" - Precision (тест): 1.0\n",
" - Recall (обучение): 1.0\n",
" - Recall (тест): 1.0\n",
" - Accuracy (обучение): 1.0\n",
" - Accuracy (тест): 1.0\n",
" - F1 Score (обучение): 1.0\n",
" - F1 Score (тест): 1.0\n",
" - ROC AUC (тест): 1.0\n",
" - Cohen Kappa (тест): 1.0\n",
" - MCC (тест): 1.0\n",
" - Confusion Matrix (тест):\n",
" ```\n",
" [[13, 0],\n",
" [ 0, 19]]\n",
" ```\n",
" - Вывод: Логистическая регрессия также показывает идеальные результаты. Это может быть связано с линейной разделимостью данных.\n",
"3. **Метод ближайших соседей (KNN)**:\n",
" - Показатели:\n",
" - Precision (обучение): 1.0\n",
" - Precision (тест): 1.0\n",
" - Recall (обучение): 0.95\n",
" - Recall (тест): 0.947\n",
" - Accuracy (обучение): 0.968\n",
" - Accuracy (тест): 0.969\n",
" - F1 Score (обучение): 0.974\n",
" - F1 Score (тест): 0.973\n",
" - ROC AUC (тест): 0.974\n",
" - Cohen Kappa (тест): 0.936\n",
" - MCC (тест): 0.938\n",
" - Confusion Matrix (тест):\n",
" ```\n",
" [[13, 0],\n",
" [ 1, 18]]\n",
" ```\n",
" - Вывод: Метод ближайших соседей показывает хорошие результаты, с небольшим снижением полноты на тестовой выборке. Это связано с особенностями алгоритма, который может быть чувствителен к выбросам и распределению данных.\n"
]
},
{
"cell_type": "code",
"execution_count": 391,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Модель: RandomForestClassifier\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[13 0]\n",
" [ 0 19]]\n",
"\n",
"Модель: LogisticRegression\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[13 0]\n",
" [ 0 19]]\n",
"\n",
"Модель: KNN\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 0.95\n",
"\tRecall_test: 0.9473684210526315\n",
"\tAccuracy_train: 0.967741935483871\n",
"\tAccuracy_test: 0.96875\n",
"\tF1_train: 0.9743589743589743\n",
"\tF1_test: 0.972972972972973\n",
"\tROC_AUC_test: 0.9736842105263157\n",
"\tCohen_kappa_test: 0.9359999999999999\n",
"\tMCC_test: 0.9379228369755696\n",
"\tConfusion_matrix: [[13 0]\n",
" [ 1 18]]\n",
"\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn import metrics\n",
"\n",
"\n",
"# Оценка качества различных моделей на основе метрик\n",
"def evaluate_models(models: dict[str, Any], \n",
" pipeline_end: Pipeline, \n",
" X_train: DataFrame, y_train, \n",
" X_test: DataFrame, y_test) -> dict[str, dict[str, Any]]:\n",
" results: dict[str, dict[str, Any]] = {}\n",
" \n",
" for model_name, model in models.items():\n",
" # Создание конвейера для текущей модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", model),\n",
" ]\n",
" )\n",
" \n",
" # Обучение модели\n",
" model_pipeline.fit(X_train, y_train)\n",
" \n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
" \n",
" # Вычисление метрик для текущей модели\n",
" metrics_dict: dict[str, Any] = {\n",
" \"Precision_train\": metrics.precision_score(y_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
" }\n",
" \n",
" # Сохранение результатов\n",
" results[model_name] = metrics_dict\n",
" \n",
" return results\n",
"\n",
"\n",
"# Выбранные модели для классификации\n",
"models_classification: dict[str, Any] = {\n",
" \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
" \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
" \"KNN\": KNeighborsClassifier(),\n",
"}\n",
"\n",
"results: dict[str, dict[str, Any]] = evaluate_models(models_classification,\n",
" pipeline_end,\n",
" X_df_train, y_df_train,\n",
" X_df_test, y_df_test)\n",
"\n",
"# Вывод результатов\n",
"for model_name, metrics_dict in results.items():\n",
" print(f\"Модель: {model_name}\")\n",
" for metric_name, value in metrics_dict.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Матрица ошибок:"
]
},
{
"cell_type": "code",
"execution_count": 392,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABEoAAAQTCAYAAABzx8zfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADcLUlEQVR4nOzdeZyN5f/H8feZYfaNDGMYY9/CiCLZQ0NRJAmFhCIhEZJdTWRpsf5aUKlEQpQ1+1IpQyRZs8s+xjbLuX9/zJnz7ZjFGXNus3g9H4/7kXPf97nOdY5p3rfPua7rthiGYQgAAAAAAAByy+oOAAAAAAAAZBcUSgAAAAAAAGwolAAAAAAAANhQKAEAAAAAALChUAIAAAAAAGBDoQQAAAAAAMCGQgkAAAAAAIBNnqzuAAAAyP6uX7+uuLg4l7bp4eEhLy8vl7YJAMDdgmw2D4USAACQruvXr6tEuJ9O/Zvo0nZDQkJ06NAhLsgAAMggstlcFEoAAEC64uLidOrfRP3zW3EF+Ltm1m7MZavCqx9WXFzcXX8xBgBARpHN5qJQAgAAnOLnb5Gfv8UlbVnlmnYAALibkc3mYDFXAAAAAAAAG0aUAAAApyQaViUarmsLAABkDtlsDgolAADAKVYZsso1V2OuagcAgLsZ2WwOpt4AAAAAAADYMKIEAAA4xSqrXDUo13UtAQBw9yKbzUGhBAAAOCXRMJRouGZYrqvaAQDgbkY2m4OpNwAAAAAAADYUSgDcUufOnVW8ePGs7sZd5fDhw7JYLJo1a1aW9aF48eLq3Lmzw759+/bpkUceUWBgoCwWixYuXKhZs2bJYrHo8OHDWdJP3DnJC8a5agNwZzVo0EANGjRwWXup5QScs3btWlksFq1duzaru4Icjmw2B4USIJtJ/kdn8pYnTx4VKVJEnTt31vHjx7O6e3dc586dHT6P/27Lli3L6u6lcOLECY0YMULR0dFpnrN27Vo9+eSTCgkJkYeHhwoWLKgWLVpowYIFd66jt6lTp076448/9NZbb+nzzz/X/fffn9VdAoAcKTnvt23bltVdSdfmzZs1YsQIXbx40SXtJX8RkLy5ubkpf/78atasmbZs2eKS1wCAzGKNEiCbGjVqlEqUKKHr169r69atmjVrljZu3Khdu3bJy8srq7t3R3l6eurjjz9OsT8iIiILepO+EydOaOTIkSpevLiqVq2a4vjw4cM1atQolSlTRi+++KLCw8N17tw5/fDDD2rdurXmzJmj9u3b3/mOp2Lv3r1yc/tfPf3atWvasmWLhgwZol69etn3P/fcc3rmmWfk6emZFd3EHWSVoURuQQjkWCtWrMjwczZv3qyRI0eqc+fOCgoKcjh2c05kRLt27fToo48qMTFRf//9t6ZOnaqGDRvq119/VeXKlW+rzZykXr16unbtmjw8PLK6K8jhyGZzUCgBsqlmzZrZv63v2rWrChQooLFjx2rx4sV6+umns7h3d1aePHn07LPPmtL21atX5ePjY0rbN5s/f75GjRqlp556Sl9++aXy5s1rPzZgwAAtX75c8fHxd6Qvzri58HHmzBlJSnGh7O7uLnd3d5e97pUrV+Tr6+uy9uA6rhyWy8UYcOe5+h/lmSmQV6tWzSHb69atq2bNmmnatGmaOnWqK7rntKzIHTc3t7vuiy+Yg2w2B1NvgByibt26kqQDBw5IkuLi4jRs2DBVr15dgYGB8vX1Vd26dbVmzRqH5yUPcR0/frz+7//+T6VKlZKnp6ceeOAB/frrryleZ+HChapUqZK8vLxUqVIlfffdd6n258qVK3rttdcUFhYmT09PlStXTuPHj5dx02rZFotFvXr10rx581SxYkV5e3urVq1a+uOPPyRJM2bMUOnSpeXl5aUGDRrc9joXU6dO1b333itPT0+Fhobq5ZdfTjFMuEGDBqpUqZJ+++031atXTz4+PnrjjTckSTdu3NDw4cNVunRpeXp6KiwsTK+//rpu3Ljh0MbKlStVp04dBQUFyc/PT+XKlbO3sXbtWj3wwAOSpOeff94+rDh5nZGhQ4cqf/78+vTTTx2KJMkiIyPVvHnzNN/jzp071blzZ5UsWVJeXl4KCQlRly5ddO7cOYfzLl++rL59+6p48eLy9PRUwYIF1aRJE/3+++/2c/bt26fWrVsrJCREXl5eKlq0qJ555hldunTJfs5/556PGDFC4eHhkpKKOhaLxb5uTVprlPz444+qW7eufH195e/vr8cee0y7d+92OKdz587y8/PTgQMH9Oijj8rf318dOnRI8zMAgLvJ9u3b1axZMwUEBMjPz0+NGjXS1q1bU5y3c+dO1a9fX97e3ipatKjGjBmjmTNnpvjdnNoaJR9++KHuvfde+fj4KF++fLr//vv15ZdfSkr63T9gwABJUokSJey5ltxmamuUXLx4Ua+++qo9g4oWLaqOHTvq7Nmz6b7Xm69z/tte37597dcbpUuX1tixY2W1Ot7G9Ny5c3ruuecUEBCgoKAgderUSTt27Eix3ld6uWO1WvXee+/p3nvvlZeXlwoVKqQXX3xRFy5ccHitbdu2KTIyUgUKFJC3t7dKlCihLl26OJzz9ddfq3r16vL391dAQIAqV66s999/3348rTVK5s2bp+rVq8vb21sFChTQs88+m2LqdfJ7OH78uFq2bCk/Pz8FBwerf//+SkxMTPdzBuAcRpQAOUTyRUm+fPkkSTExMfr444/Vrl07devWTZcvX9Ynn3yiyMhI/fLLLymmfXz55Ze6fPmyXnzxRVksFo0bN05PPvmkDh48aP9H+4oVK9S6dWtVrFhRUVFROnfunJ5//nkVLVrUoS3DMPT4449rzZo1euGFF1S1alUtX75cAwYM0PHjxzVp0iSH8zds2KDFixfr5ZdfliRFRUWpefPmev311zV16lT17NlTFy5c0Lhx49SlSxf99NNPKd7/zRdYefPmVWBgoKSkC7mRI0eqcePG6tGjh/bu3atp06bp119/1aZNmxyKEufOnVOzZs30zDPP6Nlnn1WhQoVktVr1+OOPa+PGjerevbsqVKigP/74Q5MmTdLff/+thQsXSpJ2796t5s2bq0qVKho1apQ8PT21f/9+bdq0SZJUoUIFjRo1SsOGDVP37t3tF30PPfSQ9u3bp7/++ktdunSRv7+/U3/nN1u5cqUOHjyo559/XiEhIdq9e7f+7//+T7t379bWrVtlsVgkSS+99JLmz5+vXr16qWLFijp37pw2btyoPXv2qFq1aoqLi1NkZKRu3LihV155RSEhITp+/LiWLFmiixcv2j/X/3ryyScVFBSkV1991T5c2s/PL82+fv755+rUqZMiIyM1duxYXb16VdOmTVOdOnW0fft2h8WBExISFBkZqTp16mj8+PF3bIQPMo5bEAJ3zu7du1W3bl0FBATo9ddfV968eTVjxgw1aNBA69atU82aNSVJx48fV8OGDWWxWDR48GD5+vrq448/dmq0x0cffaTevXvrqaeeUp8+fXT9+nXt3LlTP//8s9q3b68nn3xSf//9t7766itNmjRJBQoUkCQFBwen2l5sbKzq1q2rPXv2qEuXLqpWrZrOnj2rxYsX69ixY/bnp+bm6xwpadRn/fr1dfz4cb344osqVqyYNm/erMGDB+vkyZN67733JCUVOFq0aKFffvlFPXr0UPny5bVo0SJ16tQp1ddKK3defPFFzZo1S88//7x69+6tQ4cOafLkydq+fbv9euLff//VI488ouDgYA0aNEhBQUE6fPiwwzpjK1euVLt27dSoUSONHTtWkrRnzx5t2rRJffr0SfMzSH7tBx54QFFRUTp9+rTef/99bdq0Sdu3b3cY0ZmYmKjIyEjVrFlT48eP16pVqzRhwgSVKlVKPXr0SPM1kPuQzSYxAGQrM2fONCQZq1atMs6cOWMcPXrUmD9/vhEcHGx4enoaR48eNQzDMBISEowbN244PPfChQtGoUKFjC5dutj3HTp0yJBk3HPPPcb58+ft+xctWmRIMr7//nv7vqpVqxqFCxc2Ll68aN+3YsUKQ5IRHh5u37dw4UJDkjFmzBiH13/qqac
"text/plain": [
"<Figure size 1200x1000 with 7 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"\n",
"_, ax = plt.subplots(ceil(len(models_classification) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"\n",
"for index, key in enumerate(models_classification.keys()):\n",
" c_matrix = results[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Sale\", \"Option Exercise\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Подбор гиперпараметров:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие параметры: {'model__criterion': 'gini', 'model__max_depth': 5, 'model__max_features': 'sqrt', 'model__n_estimators': 10}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
}
],
"source": [
"# Создание конвейера\n",
"pipeline = Pipeline([\n",
" (\"processing\", pipeline_end),\n",
" (\"model\", RandomForestClassifier(random_state=42))\n",
"])\n",
"\n",
"# Установка параметров для поиска по сетке\n",
"param_grid: dict[str, Any] = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке\n",
"grid_search = GridSearchCV(estimator=pipeline, \n",
" param_grid=param_grid,\n",
" n_jobs=-1)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_df_train, y_df_train)\n",
"\n",
"# Результаты подбора гиперпараметров\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Сравнение наборов гиперпараметров:\n",
"\n",
"Результаты анализа показывают, что как стоковая модель, так и оптимизированная модель демонстрируют идентичные показатели качества, включая абсолютные значения всех ключевых метрик (Precision, Recall, Accuracy, F1-Score и другие), равные 1.0 на обеих выборках (обучающей и тестовой). Это указывает на то, что о б е модели идеально справляются с задачей классификации."
]
},
{
"cell_type": "code",
"execution_count": 401,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Стоковая модель:\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[13 0]\n",
" [ 0 19]]\n",
"\n",
"Оптимизированная модель:\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[13 0]\n",
" [ 0 19]]\n"
]
}
],
"source": [
"# Обучение модели с о старыми гипермараметрами\n",
"pipeline.fit(X_df_train, y_df_train)\n",
"\n",
"# Предсказание для обучающей и тестовой выборки\n",
"y_train_predict = pipeline.predict(X_df_train)\n",
"y_test_predict = pipeline.predict(X_df_test)\n",
" \n",
"# Вычисление метрик для модели с о старыми гипермараметрами\n",
"base_model_metrics: dict[str, Any] = {\n",
" \"Precision_train\": metrics.precision_score(y_df_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_df_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_df_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_df_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_df_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_df_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_df_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_df_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_df_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_df_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_df_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_df_test, y_test_predict),\n",
"}\n",
"\n",
"# Модель с новыми гипермараметрами\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=42,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"# Создание конвейера для модели с новыми гипермараметрами\n",
"optimized_model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", optimized_model),\n",
" ]\n",
")\n",
" \n",
"# Обучение модели с новыми гипермараметрами\n",
"optimized_model_pipeline.fit(X_df_train, y_df_train)\n",
" \n",
"# Предсказание для обучающей и тестовой выборки\n",
"y_train_predict = optimized_model_pipeline.predict(X_df_train)\n",
"y_test_predict = optimized_model_pipeline.predict(X_df_test)\n",
" \n",
"# Вычисление метрик для модели с новыми гипермараметрами\n",
"optimized_model_metrics: dict[str, Any] = {\n",
" \"Precision_train\": metrics.precision_score(y_df_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_df_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_df_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_df_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_df_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_df_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_df_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_df_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_df_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_df_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_df_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_df_test, y_test_predict),\n",
"}\n",
"\n",
"# Вывод информации\n",
"print('Стоковая модель:')\n",
"for metric_name, value in base_model_metrics.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
"\n",
"print('\\nО птимизир о ва нна я модель:')\n",
"for metric_name, value in optimized_model_metrics.items():\n",
" print(f\"\\t{metric_name}: {value}\")"
]
2024-11-16 21:46:44 +04:00
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}