2164 lines
288 KiB
Plaintext
Raw Normal View History

2024-11-30 09:49:46 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Бизнес-цели:\n",
"1. Для решения задачи регрессии: Предсказать будущую стоимость акций компании Tesla на основе инсайдерских транзакций. Стоимость акций (\"Cost\") зависит от множества факторов, включая объём и тип транзакций, совершаемых инсайдерами. Если выявить зависимости между параметрами транзакций (количество акций, общий объём сделки, должность инсайдера) и стоимостью акций, это может помочь инвесторам принимать обоснованные решения о покупке или продаже.\n",
"2. Для решения задачи классификации: Классифицировать тип инсайдерской транзакции (продажа акций или исполнение опционов) на основе характеристик сделки. Тип транзакции (\"Transaction\") может быть индикатором доверия инсайдера к текущей рыночной цене или будущей прибыльности компании. Модель, которая предсказывает тип транзакции, может помочь в оценке поведения инсайдеров и выявлении аномалий."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выгрузка данных из файла в датафрейм"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Tuple\n",
"from math import ceil\n",
"from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"import matplotlib.pyplot as plt\n",
"\n",
"df: DataFrame = pd.read_csv(\"static/csv/TSLA.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Преобразование данных"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Выборка данных:\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Insider Trading</th>\n",
" <th>Relationship</th>\n",
" <th>Transaction</th>\n",
" <th>Cost</th>\n",
" <th>Shares</th>\n",
" <th>Value ($)</th>\n",
" <th>Shares Total</th>\n",
" <th>Year</th>\n",
" <th>Month</th>\n",
" <th>Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Kirkhorn Zachary</td>\n",
" <td>Chief Financial Officer</td>\n",
" <td>Sale</td>\n",
" <td>196.72</td>\n",
" <td>10455</td>\n",
" <td>2056775</td>\n",
" <td>203073</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Taneja Vaibhav</td>\n",
" <td>Chief Accounting Officer</td>\n",
" <td>Sale</td>\n",
" <td>195.79</td>\n",
" <td>2466</td>\n",
" <td>482718</td>\n",
" <td>100458</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Sale</td>\n",
" <td>195.79</td>\n",
" <td>1298</td>\n",
" <td>254232</td>\n",
" <td>65547</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Taneja Vaibhav</td>\n",
" <td>Chief Accounting Officer</td>\n",
" <td>Option Exercise</td>\n",
" <td>0.00</td>\n",
" <td>7138</td>\n",
" <td>0</td>\n",
" <td>102923</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Option Exercise</td>\n",
" <td>0.00</td>\n",
" <td>2586</td>\n",
" <td>0</td>\n",
" <td>66845</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Kirkhorn Zachary</td>\n",
" <td>Chief Financial Officer</td>\n",
" <td>Option Exercise</td>\n",
" <td>0.00</td>\n",
" <td>16867</td>\n",
" <td>0</td>\n",
" <td>213528</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Option Exercise</td>\n",
" <td>20.91</td>\n",
" <td>10500</td>\n",
" <td>219555</td>\n",
" <td>74759</td>\n",
" <td>2022</td>\n",
" <td>2</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Sale</td>\n",
" <td>202.00</td>\n",
" <td>10500</td>\n",
" <td>2121000</td>\n",
" <td>64259</td>\n",
" <td>2022</td>\n",
" <td>2</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>Kirkhorn Zachary</td>\n",
" <td>Chief Financial Officer</td>\n",
" <td>Sale</td>\n",
" <td>193.00</td>\n",
" <td>3750</td>\n",
" <td>723750</td>\n",
" <td>196661</td>\n",
" <td>2022</td>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Option Exercise</td>\n",
" <td>20.91</td>\n",
" <td>10500</td>\n",
" <td>219555</td>\n",
" <td>74759</td>\n",
" <td>2022</td>\n",
" <td>1</td>\n",
" <td>27</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Insider Trading Relationship Transaction Cost \\\n",
"0 Kirkhorn Zachary Chief Financial Officer Sale 196.72 \n",
"1 Taneja Vaibhav Chief Accounting Officer Sale 195.79 \n",
"2 Baglino Andrew D SVP Powertrain and Energy Eng. Sale 195.79 \n",
"3 Taneja Vaibhav Chief Accounting Officer Option Exercise 0.00 \n",
"4 Baglino Andrew D SVP Powertrain and Energy Eng. Option Exercise 0.00 \n",
"5 Kirkhorn Zachary Chief Financial Officer Option Exercise 0.00 \n",
"6 Baglino Andrew D SVP Powertrain and Energy Eng. Option Exercise 20.91 \n",
"7 Baglino Andrew D SVP Powertrain and Energy Eng. Sale 202.00 \n",
"8 Kirkhorn Zachary Chief Financial Officer Sale 193.00 \n",
"9 Baglino Andrew D SVP Powertrain and Energy Eng. Option Exercise 20.91 \n",
"\n",
" Shares Value ($) Shares Total Year Month Day \n",
"0 10455 2056775 203073 2022 3 6 \n",
"1 2466 482718 100458 2022 3 6 \n",
"2 1298 254232 65547 2022 3 6 \n",
"3 7138 0 102923 2022 3 5 \n",
"4 2586 0 66845 2022 3 5 \n",
"5 16867 0 213528 2022 3 5 \n",
"6 10500 219555 74759 2022 2 27 \n",
"7 10500 2121000 64259 2022 2 27 \n",
"8 3750 723750 196661 2022 2 6 \n",
"9 10500 219555 74759 2022 1 27 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Преобразование типов данных\n",
"df['Insider Trading'] = df['Insider Trading'].astype('category') # Преобразование в категорию\n",
"df['Relationship'] = df['Relationship'].astype('category') # Преобразование в категорию\n",
"df['Transaction'] = df['Transaction'].astype('category') # Преобразование в категорию\n",
"df['Cost'] = pd.to_numeric(df['Cost'], errors='coerce') # Преобразование в float\n",
"df['Shares'] = pd.to_numeric(df['Shares'].str.replace(',', ''), errors='coerce') # Преобразование в float без запятых\n",
"df['Value ($)'] = pd.to_numeric(df['Value ($)'].str.replace(',', ''), errors='coerce') # Преобразование в float без запятых\n",
"df['Shares Total'] = pd.to_numeric(df['Shares Total'].str.replace(',', ''), errors='coerce') # Преобразование в float без запятых\n",
"\n",
"df['Date'] = pd.to_datetime(df['Date'], errors='coerce') # Преобразование в datetime\n",
"df['Year'] = df['Date'].dt.year # Год\n",
"df['Month'] = df['Date'].dt.month # Месяц\n",
"df['Day'] = df['Date'].dt.day # День\n",
"df: DataFrame = df.drop(columns=['Date', 'SEC Form 4']) # Удаление столбцов с датами\n",
"\n",
"print('Выборка данных:')\n",
"df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Задача регрессии:\n",
"\n",
"Основные метрики для регрессии:\n",
"* Средняя абсолютная ошибка (Mean Absolute Error, MAE) показывает среднее абсолютное отклонение между предсказанными и фактическими значениями. Легко интерпретируется, особенно в финансовых данных, где каждая ошибка в долларах имеет значение.\n",
"* Среднеквадратичная ошибка (Mean Squared Error, MSE) показывает, насколько отклоняются прогнозы модели от истинных значений в квадрате. Подходит для оценки общего качества модели.\n",
"* Коэффициент детерминации (R²) указывает, какую долю дисперсии зависимой переменной объясняет модель. R² варьируется от 0 до 1 (чем ближе к 1, тем лучше)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В качестве базовой модели для оценки качества предсказаний выбрано использование среднего значения целевой переменной (Cost) на обучающей выборке. Это простой и интуитивно понятный метод, который служит минимальным ориентиром для сравнения с более сложными моделями. Базовая модель помогает установить начальный уровень ошибок (MAE, MSE) и показатель качества (R²), которые сложные модели должны улучшить, чтобы оправдать своё использование."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline MAE: 417.78235887096776\n",
"Baseline MSE: 182476.07973024843\n",
"Baseline R²: -0.027074997920953914\n"
]
}
],
"source": [
"from pandas.core.frame import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
"\n",
"\n",
"# Разбить данные на обучающую и тестовую выборки\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" stratify_colname: str = \"y\", \n",
" frac_train: float = 0.8,\n",
" random_state: int = 42,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" \n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" X: DataFrame = df_input # Contains all columns.\n",
" y: DataFrame = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
"\n",
" # Split original dataframe into train and test dataframes.\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" \n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"\n",
"# Определяем целевой признак и входные признаки\n",
"y_feature: str = 'Cost'\n",
"X_features: list[str] = df.drop(columns=y_feature, axis=1).columns.tolist()\n",
"\n",
"# Разбиваем данные на обучающую и тестовую выборки\n",
"X_df_train, X_df_test, y_df_train, y_df_test = split_into_train_test(\n",
" df, \n",
" stratify_colname=y_feature, \n",
" frac_train=0.8, \n",
" random_state=42 \n",
")\n",
"\n",
"# Вычисляем предсказания базовой модели (среднее значение целевой переменной)\n",
"baseline_predictions: list[float] = [y_df_train.mean()] * len(y_df_test) # type: ignore\n",
"\n",
"# Оцениваем базовую модель\n",
"print('Baseline MAE:', mean_absolute_error(y_df_test, baseline_predictions))\n",
"print('Baseline MSE:', mean_squared_error(y_df_test, baseline_predictions))\n",
"print('Baseline R²:', r2_score(y_df_test, baseline_predictions))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Для обучения были выбраны следующие модели:\n",
"\n",
"1. Случайный лес (Random Forest): Ансамблевая модель, которая использует множество решающих деревьев. Она хорошо справляется с нелинейными зависимостями и шумом в данных, а также обладает устойчивостью к переобучению.\n",
"2. Линейная регрессия (Linear Regression): Простая модель, предполагающая линейную зависимость между признаками и целевой переменной. Она быстро обучается и предоставляет легкую интерпретацию результатов.\n",
"3. Градиентный бустинг (Gradient Boosting): Мощная модель, создающая ансамбль деревьев, которые корректируют ошибки предыдущих. Эта модель эффективна для сложных наборов данных и обеспечивает высокую точность предсказаний."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Построение конвейера:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Cost</th>\n",
" <th>Shares</th>\n",
" <th>Value ($)</th>\n",
" <th>Shares Total</th>\n",
" <th>Year</th>\n",
" <th>Month</th>\n",
" <th>Day</th>\n",
" <th>Insider Trading_DENHOLM ROBYN M</th>\n",
" <th>Insider Trading_Kirkhorn Zachary</th>\n",
" <th>Insider Trading_Musk Elon</th>\n",
" <th>Insider Trading_Musk Kimbal</th>\n",
" <th>Insider Trading_Taneja Vaibhav</th>\n",
" <th>Insider Trading_Wilson-Thompson Kathleen</th>\n",
" <th>Relationship_Chief Accounting Officer</th>\n",
" <th>Relationship_Chief Financial Officer</th>\n",
" <th>Relationship_Director</th>\n",
" <th>Relationship_SVP Powertrain and Energy Eng.</th>\n",
" <th>Transaction_Sale</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.966516</td>\n",
" <td>-0.361759</td>\n",
" <td>-0.450022</td>\n",
" <td>-0.343599</td>\n",
" <td>0.715678</td>\n",
" <td>-0.506108</td>\n",
" <td>-0.400623</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>-1.074894</td>\n",
" <td>1.225216</td>\n",
" <td>-0.414725</td>\n",
" <td>-0.319938</td>\n",
" <td>-1.397276</td>\n",
" <td>0.801338</td>\n",
" <td>0.906673</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-1.074894</td>\n",
" <td>1.211753</td>\n",
" <td>-0.415027</td>\n",
" <td>-0.320141</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.098939</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.167142</td>\n",
" <td>0.037499</td>\n",
" <td>1.023612</td>\n",
" <td>-0.325853</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.501184</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.217886</td>\n",
" <td>-0.075287</td>\n",
" <td>0.632973</td>\n",
" <td>-0.330205</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.501184</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.505872</td>\n",
" <td>-0.361021</td>\n",
" <td>-0.443679</td>\n",
" <td>-0.343698</td>\n",
" <td>0.715678</td>\n",
" <td>-0.767598</td>\n",
" <td>1.308918</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>-1.088674</td>\n",
" <td>-0.357532</td>\n",
" <td>-0.450389</td>\n",
" <td>-0.342863</td>\n",
" <td>0.715678</td>\n",
" <td>0.278360</td>\n",
" <td>-0.903429</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>-0.692146</td>\n",
" <td>-0.355855</td>\n",
" <td>-0.445383</td>\n",
" <td>-0.343220</td>\n",
" <td>0.715678</td>\n",
" <td>0.801338</td>\n",
" <td>1.409480</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>-1.088674</td>\n",
" <td>-0.361181</td>\n",
" <td>-0.450389</td>\n",
" <td>-0.343649</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.903429</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1.091997</td>\n",
" <td>-0.204531</td>\n",
" <td>0.114712</td>\n",
" <td>1.538166</td>\n",
" <td>0.715678</td>\n",
" <td>-1.029087</td>\n",
" <td>1.208357</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Cost Shares Value ($) Shares Total Year Month Day \\\n",
"0 -0.966516 -0.361759 -0.450022 -0.343599 0.715678 -0.506108 -0.400623 \n",
"1 -1.074894 1.225216 -0.414725 -0.319938 -1.397276 0.801338 0.906673 \n",
"2 -1.074894 1.211753 -0.415027 -0.320141 -1.397276 1.062828 -0.098939 \n",
"3 1.167142 0.037499 1.023612 -0.325853 -1.397276 1.062828 -0.501184 \n",
"4 1.217886 -0.075287 0.632973 -0.330205 -1.397276 1.062828 -0.501184 \n",
"5 0.505872 -0.361021 -0.443679 -0.343698 0.715678 -0.767598 1.308918 \n",
"6 -1.088674 -0.357532 -0.450389 -0.342863 0.715678 0.278360 -0.903429 \n",
"7 -0.692146 -0.355855 -0.445383 -0.343220 0.715678 0.801338 1.409480 \n",
"8 -1.088674 -0.361181 -0.450389 -0.343649 -1.397276 1.062828 -0.903429 \n",
"9 1.091997 -0.204531 0.114712 1.538166 0.715678 -1.029087 1.208357 \n",
"\n",
" Insider Trading_DENHOLM ROBYN M Insider Trading_Kirkhorn Zachary \\\n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Insider Trading_Musk Elon Insider Trading_Musk Kimbal \\\n",
"0 0.0 0.0 \n",
"1 1.0 0.0 \n",
"2 1.0 0.0 \n",
"3 1.0 0.0 \n",
"4 1.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 1.0 0.0 \n",
"\n",
" Insider Trading_Taneja Vaibhav Insider Trading_Wilson-Thompson Kathleen \\\n",
"0 1.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 1.0 0.0 \n",
"7 0.0 0.0 \n",
"8 1.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Relationship_Chief Accounting Officer \\\n",
"0 1.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"5 0.0 \n",
"6 1.0 \n",
"7 0.0 \n",
"8 1.0 \n",
"9 0.0 \n",
"\n",
" Relationship_Chief Financial Officer Relationship_Director \\\n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Relationship_SVP Powertrain and Energy Eng. Transaction_Sale \n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 1.0 \n",
"4 0.0 1.0 \n",
"5 1.0 1.0 \n",
"6 0.0 0.0 \n",
"7 1.0 1.0 \n",
"8 0.0 0.0 \n",
"9 0.0 1.0 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.impute import SimpleImputer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"\n",
"# Числовые столбцы\n",
"num_columns: list[str] = [\n",
" column\n",
" for column in df.columns\n",
" if df[column].dtype not in (\"category\", \"object\")\n",
"]\n",
"\n",
"# Категориальные столбцы\n",
"cat_columns: list[str] = [\n",
" column\n",
" for column in df.columns\n",
" if df[column].dtype in (\"category\", \"object\")\n",
"]\n",
"\n",
"# Заполнение пропущенных значений\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"# Стандартизация\n",
"num_scaler = StandardScaler()\n",
"# Конвейер для обработки числовых данных\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Заполнение пропущенных значений\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"# Унитарное кодирование\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"# Конвейер для обработки категориальных данных\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Трансформер для предобработки признаков\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Основной конвейер предобработки данных\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")\n",
"\n",
"# Применение конвейера\n",
"preprocessing_result = pipeline_end.fit_transform(X_df_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение моделей:\n",
"\n",
"Оценка результатов обучения:\n",
"1. Случайный лес (Random Forest):\n",
" - Показатели:\n",
" - Средний балл: 0.9993.\n",
" - Стандартное отклонение: 0.00046.\n",
" - Вывод: Очень высокая точность, что свидетельствует о хорошей способности модели к обобщению. Низкое значение стандартного отклонения указывает на стабильность модели.\n",
"2. Линейная регрессия (Linear Regression):\n",
" - Показатели:\n",
" - Средний балл: 1.0.\n",
" - Стандартное отклонение: 0.0.\n",
" - Вывод: Идеальная точность, однако есть вероятность переобучения, так как стандартное отклонение равно 0. Это может указывать на то, что модель идеально подгоняет данные, но может не работать на новых данных.\n",
"3. Градиентный бустинг (Gradient Boosting):\n",
" - Показатели:\n",
" - Средний балл: 0.9998.\n",
" - Стандартное отклонение: 0.00014.\n",
" - Вывод: Отличные результаты с высокой точностью и низкой вариабельностью. Модель также демонстрирует хорошую устойчивость."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"e:\\aim\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Модель: Random Forest\n",
"\tmean_score: 0.9993408296679137\n",
"\tstd_dev: 0.0003334834737331458\n",
"\n",
"Модель: Linear Regression\n",
"\tmean_score: 1.0\n",
"\tstd_dev: 0.0\n",
"\n",
"Модель: Gradient Boosting\n",
"\tmean_score: 0.9997626429479803\n",
"\tstd_dev: 0.00014954088559894797\n",
"\n"
]
}
],
"source": [
"# Обучить модели\n",
"def train_models(X: DataFrame, y: DataFrame, \n",
" models: dict[str, Any]) -> dict[str, dict[str, Any]]:\n",
" results: dict[str, dict[str, Any]] = {}\n",
" \n",
" for model_name, model in models.items():\n",
" # Создание конвейера для текущей модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"model\", model)\n",
" ]\n",
" )\n",
" \n",
" # Обучаем модель и вычисляем кросс-валидацию\n",
" scores = cross_val_score(model_pipeline, X, y, cv=5) # 5-кратная кросс-валидация\n",
" \n",
" # Вычисление метрик для текущей модели\n",
" metrics_dict: dict[str, Any] = {\n",
" \"mean_score\": scores.mean(),\n",
" \"std_dev\": scores.std()\n",
" }\n",
" \n",
" # Сохранениерезультатов\n",
" results[model_name] = metrics_dict\n",
" \n",
" return results\n",
"\n",
"\n",
"# Выбранные модели для регрессии\n",
"models_regression: dict[str, Any] = {\n",
" \"Random Forest\": RandomForestRegressor(),\n",
" \"Linear Regression\": LinearRegression(),\n",
" \"Gradient Boosting\": GradientBoostingRegressor(),\n",
"}\n",
"\n",
"results: dict[str, Any] = train_models(X_df_train, y_df_train, models_regression)\n",
"\n",
"# Вывод результатов\n",
"for model_name, metrics_dict in results.items():\n",
" print(f\"Модель: {model_name}\")\n",
" for metric_name, value in metrics_dict.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Проверка на тестовом наборе данных:\n",
"\n",
"Оценка результатов обучения:\n",
"1. Случайный лес (Random Forest):\n",
" - Показатели:\n",
" - MAE (обучение): 1.858\n",
" - MAE (тест): 4.489\n",
" - MSE (обучение): 10.959\n",
" - MSE (тест): 62.649\n",
" - R² (обучение): 0.9999\n",
" - R² (тест): 0.9997\n",
" - STD (обучение): 3.310\n",
" - STD (тест): 7.757\n",
" - Вывод: Случайный лес показывает великолепные значения R2 на обучающей и тестовой выборках, что свидетельствует о сильной способности к обобщению. Однако MAE и MSE на тестовой выборке значительно выше, чем на обучающей, что может указывать на некоторые проблемы с переобучением.\n",
"2. Линейная регрессия (Linear Regression):\n",
" - Показатели:\n",
" - MAE (обучение): 3.069e-13\n",
" - MAE (тест): 2.762e-13\n",
" - MSE (обучение): 1.437e-25\n",
" - MSE (тест): 1.196e-25\n",
" - R² (обучение): 1.0\n",
" - R² (тест): 1.0\n",
" - STD (обучение): 3.730e-13\n",
" - STD (тест): 3.444e-13\n",
" - Вывод: Высокие показатели точности и нулевые ошибки (MAE, MSE) указывают на то, что модель идеально подгоняет данные как на обучающей, так и на тестовой выборках. Однако это также может быть признаком переобучения.\n",
"3. Градиентный бустинг (Gradient Boosting):\n",
" - Показатели:\n",
" - MAE (обучение): 0.156\n",
" - MAE (тест): 3.027\n",
" - MSE (обучение): 0.075\n",
" - MSE (тест): 41.360\n",
" - R² (обучение): 0.9999996\n",
" - R² (тест): 0.9998\n",
" - STD (обучение): 0.274\n",
" - STD (тест): 6.399\n",
" - Вывод: Градиентный бустинг демонстрирует отличные результаты на обучающей выборке, однако MAE и MSE на тестовой выборке довольно высокие, что может указывать на определенное переобучение или необходимость улучшения настройки модели."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Модель: Random Forest\n",
"\tMAE_train: 1.723261290322547\n",
"\tMAE_test: 4.921812500000078\n",
"\tMSE_train: 7.688148654032178\n",
"\tMSE_test: 71.0254988287519\n",
"\tR2_train: 0.9999625112378799\n",
"\tR2_test: 0.9996002297168499\n",
"\tSTD_train: 2.7687940152490023\n",
"\tSTD_test: 8.368380620302354\n",
"\n",
"Модель: Linear Regression\n",
"\tMAE_train: 3.0690862038154006e-13\n",
"\tMAE_test: 2.761679773755077e-13\n",
"\tMSE_train: 1.4370485712253764e-25\n",
"\tMSE_test: 1.19585889812782e-25\n",
"\tR2_train: 1.0\n",
"\tR2_test: 1.0\n",
"\tSTD_train: 3.7295840825107354e-13\n",
"\tSTD_test: 3.4438670391637766e-13\n",
"\n",
"Модель: Gradient Boosting\n",
"\tMAE_train: 0.15613772760447622\n",
"\tMAE_test: 3.1034303023675966\n",
"\tMSE_train: 0.07499640211231746\n",
"\tMSE_test: 45.07615310256558\n",
"\tR2_train: 0.9999996343043813\n",
"\tR2_test: 0.9997462868014123\n",
"\tSTD_train: 0.27385470985965804\n",
"\tSTD_test: 6.690171573523703\n",
"\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"from sklearn import metrics\n",
"\n",
"\n",
"# Оценка качества различных моделей на основе метрик\n",
"def evaluate_models(models: dict[str, Any], \n",
" pipeline_end: Pipeline, \n",
" X_train: DataFrame, y_train, \n",
" X_test: DataFrame, y_test) -> dict[str, dict[str, Any]]:\n",
" results: dict[str, dict[str, Any]] = {}\n",
" \n",
" for model_name, model in models.items():\n",
" # Создание конвейера для текущей модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", model),\n",
" ]\n",
" )\n",
" \n",
" # Обучение текущей модели\n",
" model_pipeline.fit(X_train, y_train)\n",
"\n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
"\n",
" # Вычисление метрик для текущей модели\n",
" metrics_dict: dict[str, Any] = {\n",
" \"MAE_train\": metrics.mean_absolute_error(y_train, y_train_predict),\n",
" \"MAE_test\": metrics.mean_absolute_error(y_test, y_test_predict),\n",
" \"MSE_train\": metrics.mean_squared_error(y_train, y_train_predict),\n",
" \"MSE_test\": metrics.mean_squared_error(y_test, y_test_predict),\n",
" \"R2_train\": metrics.r2_score(y_train, y_train_predict),\n",
" \"R2_test\": metrics.r2_score(y_test, y_test_predict),\n",
" \"STD_train\": np.std(y_train - y_train_predict),\n",
" \"STD_test\": np.std(y_test - y_test_predict),\n",
" }\n",
"\n",
" # Сохранение результатов\n",
" results[model_name] = metrics_dict\n",
" \n",
" return results\n",
"\n",
"\n",
"y_train = np.ravel(y_df_train) \n",
"y_test = np.ravel(y_df_test) \n",
"\n",
"results: dict[str, dict[str, Any]] = evaluate_models(models_regression,\n",
" pipeline_end,\n",
" X_df_train, y_train,\n",
" X_df_test, y_test)\n",
"\n",
"# Вывод результатов\n",
"for model_name, metrics_dict in results.items():\n",
" print(f\"Модель: {model_name}\")\n",
" for metric_name, value in metrics_dict.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
"Лучшие параметры: {'max_depth': 30, 'min_samples_split': 5, 'n_estimators': 100}\n",
"Лучший результат (MSE): 182.9564051380662\n"
]
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"\n",
"# Применение конвейера к данным\n",
"X_train_processing_result = pipeline_end.fit_transform(X_df_train)\n",
"X_test_processing_result = pipeline_end.transform(X_df_test)\n",
"\n",
"# Создание и настройка модели случайного леса\n",
"model = RandomForestRegressor()\n",
"\n",
"# Установка параметров для поиска по сетке\n",
"param_grid: dict[str, list[int | None]] = {\n",
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке\n",
"grid_search = GridSearchCV(estimator=model, \n",
" param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Результаты подбора гиперпараметров\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"# Меняем знак, так как берем отрицательное значение среднеквадратичной ошибки\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сравнение наборов гиперпараметров:\n",
"\n",
"Результаты анализа показывают, что параметры из старой сетки обеспечивают значительно лучшее качество модели. Среднеквадратическая ошибка (MSE) на кросс-валидации для старых параметров составила 179.369, что существенно ниже, чем для новых параметров (1290.656). На тестовой выборке модель с новыми параметрами показала MSE 172.574, что сопоставимо с результатами модели со старыми параметрами, однако этот результат является случайным, так как новые параметры продемонстрировали плохую кросс-валидационную ошибку, указывая на недообучение. Таким образом, параметры из старой сетки более предпочтительны, так как они обеспечивают лучшее обобщение и меньшую ошибку."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
"Старые параметры: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 50}\n",
"Лучший результат (MSE) на старых параметрах: 178.02867772065892\n",
"\n",
"Новые параметры: {'max_depth': 5, 'min_samples_split': 10, 'n_estimators': 50}\n",
"Лучший результат (MSE) на новых параметрах: 1271.4105559258353\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 165.739398344422\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 12.873981448814583\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1sAAAHWCAYAAACBjZMqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hT1RvA8W+SZrRp0r0XZQqCIkMERUQRxC1uUcEBuLfiAhFUVBwoP/fCLQqKeyCKOBCQpbJXKd0r6R4Z9/fHbSulg7akTdq+n+fpA8k99+ZtmnHfe855j0ZRFAUhhBBCCCGEEB6l9XYAQgghhBBCCNEZSbIlhBBCCCGEEG1Aki0hhBBCCCGEaAOSbAkhhBBCCCFEG5BkSwghhBBCCCHagCRbQgghhBBCCNEGJNkSQgghhBBCiDYgyZYQQgghhBBCtAFJtoQQQgghhBCiDUiyJYQQQggh2txXX33Fxo0ba28vXbqUzZs3ey8gIdqBJFtCdAG7d+9m2rRpdO/eHZPJhNVq5fjjj+e5556jvLzc2+EJIYToAv755x9uvfVWdu7cyZ9//sl1111HcXGxt8MSok1pFEVRvB2EEKLtfP3111x44YUYjUauvPJK+vfvT1VVFb/99htLlixh8uTJvPrqq94OUwghRCeXm5vLiBEj2LVrFwATJkxgyZIlXo5KiLYlyZYQndjevXs56qijiI+P56effiImJqbO9l27dvH1119z6623eilCIYQQXUllZSX//vsvAQEB9O3b19vhCNHmZBihEJ3Yk08+SUlJCW+88Ua9RAugZ8+edRItjUbDTTfdxPvvv0+fPn0wmUwMHjyYlStX1tlv37593HDDDfTp0wd/f3/CwsK48MILSUlJqdNu4cKFaDSa2p+AgAAGDBjA66+/Xqfd5MmTCQwMrBff4sWL0Wg0rFixos79q1ev5rTTTiMoKIiAgABGjRrF77//XqfNrFmz0Gg05OXl1bn/r7/+QqPRsHDhwjqP361btzrt9u/fj7+/PxqNpt7v9e233zJy5EjMZjMWi4UzzjijWfMOap6PlStXMm3aNMLCwrBarVx55ZXYbLZ67ZvzOH///TeTJ0+uHSIaHR3N1VdfTX5+foMxdOvWrc7fpObnwOe4W7dunHnmmU3+LikpKWg0Gp566ql62/r3789JJ51Ue3vFihVoNBoWL17c6PEO/hs89NBDaLVali9fXqfd1KlTMRgMbNq0qcn4NBoNs2bNqnPfvHnz0Gg0dWJrav/Gfg6M88Dn4dlnnyUpKQl/f39GjRrFv//+W++427Zt44ILLiA0NBSTycSQIUP44osvGoxh8uTJDT7+5MmT67X99ttvGTVqFBaLBavVytChQ/nggw9qt5900kn1fu9HH30UrVZbp92vv/7KhRdeSGJiIkajkYSEBG6//fZ6w41nzZpFv379CAwMxGq1ctxxx7F06dI6bZp7rJa8/0866ST69+9fr+1TTz1V7716qNdxzeuy5vhbt27F39+fK6+8sk673377DZ1Ox/Tp0xs9FjTvOWlJ/J9//jlnnHEGsbGxGI1GevTowZw5c3C5XHX2bei1XvNZ05rPrpb+PQ5+Xa1du7b2tdpQnEajkcGDB9O3b98WvSeF6Kj8vB2AEKLtfPnll3Tv3p0RI0Y0e59ffvmFRYsWccstt2A0GnnxxRc57bTTWLNmTe1Jwtq1a/njjz+45JJLiI+PJyUlhZdeeomTTjqJLVu2EBAQUOeYzz77LOHh4RQVFfHmm28yZcoUunXrxpgxY1r8O/3000+MHz+ewYMH156Qv/XWW5x88sn8+uuvHHvssS0+ZkNmzpxJRUVFvfvfffddJk2axLhx43jiiScoKyvjpZde4oQTTmDDhg31kraG3HTTTQQHBzNr1iy2b9/OSy+9xL59+2pP/lryOMuWLWPPnj1cddVVREdHs3nzZl599VU2b97Mn3/+We+EB2DkyJFMnToVUE8wH3vssdY/UW3kwQcf5Msvv+Saa67hn3/+wWKx8P333/Paa68xZ84cjj766BYdz263M3fu3Bbtc+qpp9Y78X766acbTIzfeecdiouLufHGG6moqOC5557j5JNP5p9//iEqKgqAzZs3c/zxxxMXF8e9996L2Wzm448/5txzz2XJkiWcd9559Y5rNBrrXJy49tpr67VZuHAhV199NUceeST33XcfwcHBbNiwge+++47LLruswd/trbfe4sEHH+Tpp5+u0+aTTz6hrKyM66+/nrCwMNasWcOCBQtIS0vjk08+qW1XWlrKeeedR7du3SgvL2fhwoWcf/75rFq1qvY92Nxj+Yq+ffsyZ84c7r77bi644ALOPvtsSktLmTx5MkcccQSzZ89ucv/mPCctsXDhQgIDA7njjjsIDAzkp59+YubMmRQVFTFv3rwWH88Tn13NcaiktEZr3pNCdEiKEKJTKiwsVADlnHPOafY+gAIof/31V+19+/btU0wmk3LeeefV3ldWVlZv31WrVimA8s4779Te99ZbbymAsnfv3tr7duzYoQDKk08+WXvfpEmTFLPZXO+Yn3zyiQIoP//8s6IoiuJ2u5VevXop48aNU9xud514kpOTlVNPPbX2voceekgBlNzc3DrHXLt2rQIob731Vp3HT0pKqr3977//KlqtVhk/fnyd+IuLi5Xg4GBlypQpdY6ZlZWlBAUF1bv/YDXPx+DBg5Wqqqra+5988kkFUD7//PMWP05Df4sPP/xQAZSVK1fW2xYXF6dcddVVtbd//vnnOs+xoihKUlKScsYZZzT5u+zdu1cBlHnz5tXbduSRRyqjRo2q9xiffPJJo8c7+G+gKIryzz//KAaDQbn22msVm82mxMXFKUOGDFEcDkeTsSmK+lp+6KGHam/fc889SmRkpDJ48OA6sTW1/4033ljv/jPOOKNOnDXPg7+/v5KWllZ7/+rVqxVAuf3222vvO+WUU5QBAwYoFRUVtfe53W5lxIgRSq9eveo91mWXXaYEBgbWuc9sNiuTJk2qvW232xWLxaIMGzZMKS8vr9P2wPfIqFGjan/vr7/+WvHz81PuvPPOeo/Z0Otp7ty5ikajUfbt21dvW42cnBwFUJ566qkWH6u57/+a3+PII4+s13bevHn1PmsO9Tpu6LXvcrmUE044QYmKilLy8vKUG2+8UfHz81PWrl3b6HEa09Bz0pL4G3r+pk2bpgQEBNR5DWk0GmXmzJl12h382duSz5SW/j0OfD998803CqCcdtppysGnmIf7nhSio5JhhEJ0UkVFRQBYLJYW7Td8+HAGDx5cezsxMZFzzjmH77//vnb4ir+/f+12h8NBfn4+PXv2JDg4mPXr19c7ps1mIy8vjz179vDss8+i0+kYNWpUvXZ5eXl1fg6uUrVx40Z27tzJZZddRn5+fm270tJSTjnlFFauXInb7a6zT0FBQZ1jFhYWHvI5uO+++xg0aBAXXnhhnfuXLVuG3W7n0ksvrXNMnU7HsGHD+Pnnnw95bFCHwun1+trb119/PX5+fnzzzTctfpwD/xYVFRXk5eVx3HHHATT4t6iqqsJoNB4yRofDQV5eHvn5+TidzkbblZWV1fu7HTzMqUZxcTF5eXnY7fZDPj6owxEffvhhXn/9dcaNG0deXh5vv/02fn4tG5SRnp7OggULmDFjRoPDozzh3HPPJS4urvb2sccey7Bhw2r/pgUFBfz0009cdNFFtc9DzfM7btw4du7cSXp6ep1jVlRUYDKZmnzcZcuWUVxczL333luvbUO9mmvWrOGiiy7i/PPPb7B35MDXU2lpKXl5eYwYMQJFUdiwYUOdtjWvkd27d/P444+j1Wo5/vjjW3UsOPT7v4bL5arXtqysrMG2zX0d19BqtSxcuJCSkhLGjx/Piy++yH333ceQIUMOue+Bj9fYc9KS+A98/mpeMyNHjqSsrIxt27bVbouMjCQtLa3JuFrz2dXcv0cNRVG47777OP/88xk2bFiTbdvjPSmEr5BhhEJ0UlarFaDFZXV79epV777evXtTVlZGbm4u0dHRlJeXM3fuXN566y3S09NRDqiz01AyM2jQoNr
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Установка параметров для поиска по сетке для старых значений\n",
"old_param_grid: dict[str, list[int | None]] = {\n",
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке для старых параметров\n",
"old_grid_search = GridSearchCV(estimator=model, \n",
" param_grid=old_param_grid,\n",
" scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"old_grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Результаты подбора для старых параметров\n",
"old_best_params = old_grid_search.best_params_\n",
"# Меняем знак, так как берем отрицательное значение MSE\n",
"old_best_mse = -old_grid_search.best_score_\n",
"\n",
"\n",
"# Установка параметров для поиска по сетке для новых значений\n",
"new_param_grid: dict[str, list[int]] = {\n",
" 'n_estimators': [50],\n",
" 'max_depth': [5],\n",
" 'min_samples_split': [10]\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке для новых параметров\n",
"new_grid_search = GridSearchCV(estimator=model, \n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"new_grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Результаты подбора для новых параметров\n",
"new_best_params = new_grid_search.best_params_\n",
"# Меняем знак, так как берем отрицательное значение MSE\n",
"new_best_mse = -new_grid_search.best_score_\n",
"\n",
"\n",
"# Обучение модели с лучшими параметрами для новых значений\n",
"model_best = RandomForestRegressor(**new_best_params)\n",
"model_best.fit(X_train_processing_result, y_train)\n",
"\n",
"# Прогнозирование на тестовой выборке\n",
"y_pred = model_best.predict(X_test_processing_result)\n",
"\n",
"# Оценка производительности модели\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"\n",
"# Вывод результатов\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nНовые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)\n",
"\n",
"# Обучение модели с лучшими параметрами для старых значений\n",
"model_old = RandomForestRegressor(**old_best_params)\n",
"model_old.fit(X_train_processing_result, y_train)\n",
"\n",
"# Прогнозирование на тестовой выборке для старых параметров\n",
"y_pred_old = model_old.predict(X_test_processing_result)\n",
"\n",
"# Визуализация ошибок\n",
"plt.figure(figsize=(10, 5))\n",
"plt.plot(y_test, label='Реальные значения', marker='o', linestyle='-', color='black')\n",
"plt.plot(y_pred_old, label='Предсказанные значения (старые параметры)', marker='x', linestyle='--', color='blue')\n",
"plt.plot(y_pred, label='Предсказанные значения (новые параметры)', marker='s', linestyle='--', color='orange')\n",
"plt.xlabel('Объекты')\n",
"plt.ylabel('Значения')\n",
"plt.title('Сравнение реальных и предсказанных значений')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Задача классификации:\n",
"\n",
"Основные метрики для классификации:\n",
"- Accuracy (точность) показывает долю правильно классифицированных примеров среди всех наблюдений. Легко интерпретируется, но может быть недостаточно информативной для несбалансированных классов.\n",
"- F1-Score гармоническое среднее между точностью (precision) и полнотой (recall). Подходит для задач, где важно одновременно учитывать как ложные положительные, так и ложные отрицательные ошибки, особенно при несбалансированных классах.\n",
"- ROC AUC (Area Under the ROC Curve) отражает способность модели различать положительные и отрицательные классы на всех уровнях порога вероятности. Значение от 0.5 (случайное угадывание) до 1.0 (идеальная модель). Полезна для оценки модели на несбалансированных данных.\n",
"- Cohen's Kappa измеряет степень согласия между предсказаниями модели и истинными метками с учётом случайного угадывания. Значения варьируются от -1 (полное несогласие) до 1 (идеальное согласие). Удобна для оценки на несбалансированных данных.\n",
"- MCC (Matthews Correlation Coefficient) метрика корреляции между предсказаниями и истинными классами, учитывающая все типы ошибок (TP, TN, FP, FN). Значение варьируется от -1 (полная несоответствие) до 1 (идеальное совпадение). Отлично подходит для задач с несбалансированными классами.\n",
"- Confusion Matrix (матрица ошибок) матрица ошибок отражает распределение предсказаний модели по каждому из классов."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Разбиение данных"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Самый частый класс: Sale\n",
"Baseline Accuracy: 0.59375\n",
"Baseline F1: 0.4424019607843137\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score, f1_score\n",
"\n",
"\n",
"# Определяем целевой признак и входные признаки\n",
"y_feature: str = 'Transaction'\n",
"X_features: list[str] = df.drop(columns=y_feature, axis=1).columns.tolist()\n",
"\n",
"# Разбиваем данные на обучающую и тестовую выборки\n",
"X_df_train, X_df_test, y_df_train, y_df_test = split_into_train_test(\n",
" df, \n",
" stratify_colname=y_feature, \n",
" frac_train=0.8, \n",
" random_state=42 \n",
")\n",
"\n",
"# Определяем самый частый класс\n",
"most_frequent_class = y_df_train.mode().values[0][0]\n",
"print(f\"Самый частый класс: {most_frequent_class}\")\n",
"\n",
"# Вычисляем предсказания базовой модели (все предсказания равны самому частому классу)\n",
"baseline_predictions: list[str] = [most_frequent_class] * len(y_df_test)\n",
"\n",
"# Оцениваем базовую модель\n",
"print('Baseline Accuracy:', accuracy_score(y_df_test, baseline_predictions))\n",
"print('Baseline F1:', f1_score(y_df_test, baseline_predictions, average='weighted'))\n",
"\n",
"# Унитарное кодирование для целевого признака\n",
"y_df_train = y_df_train['Transaction'].map({'Sale': 1, 'Option Exercise': 0})\n",
"y_df_test = y_df_test['Transaction'].map({'Sale': 1, 'Option Exercise': 0})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выбор моделей обучения:\n",
"\n",
"Для обучения были выбраны следующие модели:\n",
"1. **Случайный лес (Random Forest)**: Ансамблевая модель, которая использует множество решающих деревьев. Она хорошо справляется с нелинейными зависимостями и шумом в данных, а также обладает устойчивостью к переобучению.\n",
"2. **Логистическая регрессия (Logistic Regression)**: Статистический метод для бинарной классификации, который моделирует зависимость между целевой переменной и независимыми признаками, используя логистическую функцию. Она проста в интерпретации и быстра в обучении.\n",
"3. **Метод ближайших соседей (KNN)**: Алгоритм классификации, который предсказывает класс на основе ближайших k обучающих примеров. KNN интуитивно понятен и не требует обучения, но может быть медленным на больших данных и чувствительным к выбору параметров."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Конвейеры для обработки числовых и категориальных значений, а так же основной конвейер уже были построены ранее при решении задачи регрессии."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Cost</th>\n",
" <th>Shares</th>\n",
" <th>Value ($)</th>\n",
" <th>Shares Total</th>\n",
" <th>Year</th>\n",
" <th>Month</th>\n",
" <th>Day</th>\n",
" <th>Insider Trading_DENHOLM ROBYN M</th>\n",
" <th>Insider Trading_Kirkhorn Zachary</th>\n",
" <th>Insider Trading_Musk Elon</th>\n",
" <th>Insider Trading_Musk Kimbal</th>\n",
" <th>Insider Trading_Taneja Vaibhav</th>\n",
" <th>Insider Trading_Wilson-Thompson Kathleen</th>\n",
" <th>Relationship_Chief Accounting Officer</th>\n",
" <th>Relationship_Chief Financial Officer</th>\n",
" <th>Relationship_Director</th>\n",
" <th>Relationship_SVP Powertrain and Energy Eng.</th>\n",
" <th>Transaction_Sale</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.966516</td>\n",
" <td>-0.361759</td>\n",
" <td>-0.450022</td>\n",
" <td>-0.343599</td>\n",
" <td>0.715678</td>\n",
" <td>-0.506108</td>\n",
" <td>-0.400623</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>-1.074894</td>\n",
" <td>1.225216</td>\n",
" <td>-0.414725</td>\n",
" <td>-0.319938</td>\n",
" <td>-1.397276</td>\n",
" <td>0.801338</td>\n",
" <td>0.906673</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-1.074894</td>\n",
" <td>1.211753</td>\n",
" <td>-0.415027</td>\n",
" <td>-0.320141</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.098939</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.167142</td>\n",
" <td>0.037499</td>\n",
" <td>1.023612</td>\n",
" <td>-0.325853</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.501184</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.217886</td>\n",
" <td>-0.075287</td>\n",
" <td>0.632973</td>\n",
" <td>-0.330205</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.501184</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.505872</td>\n",
" <td>-0.361021</td>\n",
" <td>-0.443679</td>\n",
" <td>-0.343698</td>\n",
" <td>0.715678</td>\n",
" <td>-0.767598</td>\n",
" <td>1.308918</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>-1.088674</td>\n",
" <td>-0.357532</td>\n",
" <td>-0.450389</td>\n",
" <td>-0.342863</td>\n",
" <td>0.715678</td>\n",
" <td>0.278360</td>\n",
" <td>-0.903429</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>-0.692146</td>\n",
" <td>-0.355855</td>\n",
" <td>-0.445383</td>\n",
" <td>-0.343220</td>\n",
" <td>0.715678</td>\n",
" <td>0.801338</td>\n",
" <td>1.409480</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>-1.088674</td>\n",
" <td>-0.361181</td>\n",
" <td>-0.450389</td>\n",
" <td>-0.343649</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.903429</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1.091997</td>\n",
" <td>-0.204531</td>\n",
" <td>0.114712</td>\n",
" <td>1.538166</td>\n",
" <td>0.715678</td>\n",
" <td>-1.029087</td>\n",
" <td>1.208357</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Cost Shares Value ($) Shares Total Year Month Day \\\n",
"0 -0.966516 -0.361759 -0.450022 -0.343599 0.715678 -0.506108 -0.400623 \n",
"1 -1.074894 1.225216 -0.414725 -0.319938 -1.397276 0.801338 0.906673 \n",
"2 -1.074894 1.211753 -0.415027 -0.320141 -1.397276 1.062828 -0.098939 \n",
"3 1.167142 0.037499 1.023612 -0.325853 -1.397276 1.062828 -0.501184 \n",
"4 1.217886 -0.075287 0.632973 -0.330205 -1.397276 1.062828 -0.501184 \n",
"5 0.505872 -0.361021 -0.443679 -0.343698 0.715678 -0.767598 1.308918 \n",
"6 -1.088674 -0.357532 -0.450389 -0.342863 0.715678 0.278360 -0.903429 \n",
"7 -0.692146 -0.355855 -0.445383 -0.343220 0.715678 0.801338 1.409480 \n",
"8 -1.088674 -0.361181 -0.450389 -0.343649 -1.397276 1.062828 -0.903429 \n",
"9 1.091997 -0.204531 0.114712 1.538166 0.715678 -1.029087 1.208357 \n",
"\n",
" Insider Trading_DENHOLM ROBYN M Insider Trading_Kirkhorn Zachary \\\n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Insider Trading_Musk Elon Insider Trading_Musk Kimbal \\\n",
"0 0.0 0.0 \n",
"1 1.0 0.0 \n",
"2 1.0 0.0 \n",
"3 1.0 0.0 \n",
"4 1.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 1.0 0.0 \n",
"\n",
" Insider Trading_Taneja Vaibhav Insider Trading_Wilson-Thompson Kathleen \\\n",
"0 1.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 1.0 0.0 \n",
"7 0.0 0.0 \n",
"8 1.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Relationship_Chief Accounting Officer \\\n",
"0 1.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"5 0.0 \n",
"6 1.0 \n",
"7 0.0 \n",
"8 1.0 \n",
"9 0.0 \n",
"\n",
" Relationship_Chief Financial Officer Relationship_Director \\\n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Relationship_SVP Powertrain and Energy Eng. Transaction_Sale \n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 1.0 \n",
"4 0.0 1.0 \n",
"5 1.0 1.0 \n",
"6 0.0 0.0 \n",
"7 1.0 1.0 \n",
"8 0.0 0.0 \n",
"9 0.0 1.0 "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Применение конвейера\n",
"preprocessing_result = pipeline_end.fit_transform(X_df_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оценка результатов обучения:\n",
"1. **Случайный лес (Random Forest)**:\n",
" - Показатели:\n",
" - Precision (обучение): 1.0\n",
" - Precision (тест): 1.0\n",
" - Recall (обучение): 1.0\n",
" - Recall (тест): 1.0\n",
" - Accuracy (обучение): 1.0\n",
" - Accuracy (тест): 1.0\n",
" - F1 Score (обучение): 1.0\n",
" - F1 Score (тест): 1.0\n",
" - ROC AUC (тест): 1.0\n",
" - Cohen Kappa (тест): 1.0\n",
" - MCC (тест): 1.0\n",
" - Confusion Matrix (тест):\n",
" ```\n",
" [[13, 0],\n",
" [ 0, 19]]\n",
" ```\n",
" - Вывод: Случайный лес идеально справляется с задачей на обеих выборках. Однако столь высокие значения метрик на обучении и тесте могут указывать на переобучение модели.\n",
"2. **Логистическая регрессия (Logistic Regression)**:\n",
" - Показатели:\n",
" - Precision (обучение): 1.0\n",
" - Precision (тест): 1.0\n",
" - Recall (обучение): 1.0\n",
" - Recall (тест): 1.0\n",
" - Accuracy (обучение): 1.0\n",
" - Accuracy (тест): 1.0\n",
" - F1 Score (обучение): 1.0\n",
" - F1 Score (тест): 1.0\n",
" - ROC AUC (тест): 1.0\n",
" - Cohen Kappa (тест): 1.0\n",
" - MCC (тест): 1.0\n",
" - Confusion Matrix (тест):\n",
" ```\n",
" [[13, 0],\n",
" [ 0, 19]]\n",
" ```\n",
" - Вывод: Логистическая регрессия также показывает идеальные результаты. Это может быть связано с линейной разделимостью данных.\n",
"3. **Метод ближайших соседей (KNN)**:\n",
" - Показатели:\n",
" - Precision (обучение): 1.0\n",
" - Precision (тест): 1.0\n",
" - Recall (обучение): 0.95\n",
" - Recall (тест): 0.947\n",
" - Accuracy (обучение): 0.968\n",
" - Accuracy (тест): 0.969\n",
" - F1 Score (обучение): 0.974\n",
" - F1 Score (тест): 0.973\n",
" - ROC AUC (тест): 0.974\n",
" - Cohen Kappa (тест): 0.936\n",
" - MCC (тест): 0.938\n",
" - Confusion Matrix (тест):\n",
" ```\n",
" [[13, 0],\n",
" [ 1, 18]]\n",
" ```\n",
" - Вывод: Метод ближайших соседей показывает хорошие результаты, с небольшим снижением полноты на тестовой выборке. Это связано с особенностями алгоритма, который может быть чувствителен к выбросам и распределению данных."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Модель: RandomForestClassifier\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[13 0]\n",
" [ 0 19]]\n",
"\n",
"Модель: LogisticRegression\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[13 0]\n",
" [ 0 19]]\n",
"\n",
"Модель: KNN\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 0.95\n",
"\tRecall_test: 0.9473684210526315\n",
"\tAccuracy_train: 0.967741935483871\n",
"\tAccuracy_test: 0.96875\n",
"\tF1_train: 0.9743589743589743\n",
"\tF1_test: 0.972972972972973\n",
"\tROC_AUC_test: 0.9736842105263157\n",
"\tCohen_kappa_test: 0.9359999999999999\n",
"\tMCC_test: 0.9379228369755696\n",
"\tConfusion_matrix: [[13 0]\n",
" [ 1 18]]\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"e:\\aim\\aimenv\\Lib\\site-packages\\joblib\\externals\\loky\\backend\\context.py:136: UserWarning: Could not find the number of physical cores for the following reason:\n",
"[WinError 2] Не удается найти указанный файл\n",
"Returning the number of logical cores instead. You can silence this warning by setting LOKY_MAX_CPU_COUNT to the number of cores you want to use.\n",
" warnings.warn(\n",
" File \"e:\\aim\\aimenv\\Lib\\site-packages\\joblib\\externals\\loky\\backend\\context.py\", line 257, in _count_physical_cores\n",
" cpu_info = subprocess.run(\n",
" ^^^^^^^^^^^^^^^\n",
" File \"C:\\Users\\Владимир\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\subprocess.py\", line 548, in run\n",
" with Popen(*popenargs, **kwargs) as process:\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"C:\\Users\\Владимир\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\subprocess.py\", line 1026, in __init__\n",
" self._execute_child(args, executable, preexec_fn, close_fds,\n",
" File \"C:\\Users\\Владимир\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\subprocess.py\", line 1538, in _execute_child\n",
" hp, ht, pid, tid = _winapi.CreateProcess(executable, args,\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn import metrics\n",
"\n",
"\n",
"# Оценка качества различных моделей на основе метрик\n",
"def evaluate_models(models: dict[str, Any], \n",
" pipeline_end: Pipeline, \n",
" X_train: DataFrame, y_train, \n",
" X_test: DataFrame, y_test) -> dict[str, dict[str, Any]]:\n",
" results: dict[str, dict[str, Any]] = {}\n",
" \n",
" for model_name, model in models.items():\n",
" # Создание конвейера для текущей модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", model),\n",
" ]\n",
" )\n",
" \n",
" # Обучение модели\n",
" model_pipeline.fit(X_train, y_train)\n",
" \n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
" \n",
" # Вычисление метрик для текущей модели\n",
" metrics_dict: dict[str, Any] = {\n",
" \"Precision_train\": metrics.precision_score(y_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
" }\n",
" \n",
" # Сохранение результатов\n",
" results[model_name] = metrics_dict\n",
" \n",
" return results\n",
"\n",
"\n",
"# Выбранные модели для классификации\n",
"models_classification: dict[str, Any] = {\n",
" \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
" \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
" \"KNN\": KNeighborsClassifier(),\n",
"}\n",
"\n",
"results: dict[str, dict[str, Any]] = evaluate_models(models_classification,\n",
" pipeline_end,\n",
" X_df_train, y_df_train,\n",
" X_df_test, y_df_test)\n",
"\n",
"# Вывод результатов\n",
"for model_name, metrics_dict in results.items():\n",
" print(f\"Модель: {model_name}\")\n",
" for metric_name, value in metrics_dict.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Матрица ошибок:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABEoAAAQTCAYAAABzx8zfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADcLUlEQVR4nOzdeZyN5f/H8feZYfaNDGMYY9/CiCLZQ0NRJAmFhCIhEZJdTWRpsf5aUKlEQpQ1+1IpQyRZs8s+xjbLuX9/zJnz7ZjFGXNus3g9H4/7kXPf97nOdY5p3rfPua7rthiGYQgAAAAAAAByy+oOAAAAAAAAZBcUSgAAAAAAAGwolAAAAAAAANhQKAEAAAAAALChUAIAAAAAAGBDoQQAAAAAAMCGQgkAAAAAAIBNnqzuAAAAyP6uX7+uuLg4l7bp4eEhLy8vl7YJAMDdgmw2D4USAACQruvXr6tEuJ9O/Zvo0nZDQkJ06NAhLsgAAMggstlcFEoAAEC64uLidOrfRP3zW3EF+Ltm1m7MZavCqx9WXFzcXX8xBgBARpHN5qJQAgAAnOLnb5Gfv8UlbVnlmnYAALibkc3mYDFXAAAAAAAAG0aUAAAApyQaViUarmsLAABkDtlsDgolAADAKVYZsso1V2OuagcAgLsZ2WwOpt4AAAAAAADYMKIEAAA4xSqrXDUo13UtAQBw9yKbzUGhBAAAOCXRMJRouGZYrqvaAQDgbkY2m4OpNwAAAAAAADYUSgDcUufOnVW8ePGs7sZd5fDhw7JYLJo1a1aW9aF48eLq3Lmzw759+/bpkUceUWBgoCwWixYuXKhZs2bJYrHo8OHDWdJP3DnJC8a5agNwZzVo0EANGjRwWXup5QScs3btWlksFq1duzaru4Icjmw2B4USIJtJ/kdn8pYnTx4VKVJEnTt31vHjx7O6e3dc586dHT6P/27Lli3L6u6lcOLECY0YMULR0dFpnrN27Vo9+eSTCgkJkYeHhwoWLKgWLVpowYIFd66jt6lTp076448/9NZbb+nzzz/X/fffn9VdAoAcKTnvt23bltVdSdfmzZs1YsQIXbx40SXtJX8RkLy5ubkpf/78atasmbZs2eKS1wCAzGKNEiCbGjVqlEqUKKHr169r69atmjVrljZu3Khdu3bJy8srq7t3R3l6eurjjz9OsT8iIiILepO+EydOaOTIkSpevLiqVq2a4vjw4cM1atQolSlTRi+++KLCw8N17tw5/fDDD2rdurXmzJmj9u3b3/mOp2Lv3r1yc/tfPf3atWvasmWLhgwZol69etn3P/fcc3rmmWfk6emZFd3EHWSVoURuQQjkWCtWrMjwczZv3qyRI0eqc+fOCgoKcjh2c05kRLt27fToo48qMTFRf//9t6ZOnaqGDRvq119/VeXKlW+rzZykXr16unbtmjw8PLK6K8jhyGZzUCgBsqlmzZrZv63v2rWrChQooLFjx2rx4sV6+umns7h3d1aePHn07LPPmtL21atX5ePjY0rbN5s/f75GjRqlp556Sl9++aXy5s1rPzZgwAAtX75c8fHxd6Qvzri58HHmzBlJSnGh7O7uLnd3d5e97pUrV+Tr6+uy9uA6rhyWy8UYcOe5+h/lmSmQV6tWzSHb69atq2bNmmnatGmaOnWqK7rntKzIHTc3t7vuiy+Yg2w2B1NvgByibt26kqQDBw5IkuLi4jRs2DBVr15dgYGB8vX1Vd26dbVmzRqH5yUPcR0/frz+7//+T6VKlZKnp6ceeOAB/frrryleZ+HChapUqZK8vLxUqVIlfffdd6n258qVK3rttdcUFhYmT09PlStXTuPHj5dx02rZFotFvXr10rx581SxYkV5e3urVq1a+uOPPyRJM2bMUOnSpeXl5aUGDRrc9joXU6dO1b333itPT0+Fhobq5ZdfTjFMuEGDBqpUqZJ+++031atXTz4+PnrjjTckSTdu3NDw4cNVunRpeXp6KiwsTK+//rpu3Ljh0MbKlStVp04dBQUFyc/PT+XKlbO3sXbtWj3wwAOSpOeff94+rDh5nZGhQ4cqf/78+vTTTx2KJMkiIyPVvHnzNN/jzp071blzZ5UsWVJeXl4KCQlRly5ddO7cOYfzLl++rL59+6p48eLy9PRUwYIF1aRJE/3+++/2c/bt26fWrVsrJCREXl5eKlq0qJ555hldunTJfs5/556PGDFC4eHhkpKKOhaLxb5uTVprlPz444+qW7eufH195e/vr8cee0y7d+92OKdz587y8/PTgQMH9Oijj8rf318dOnRI8zMAgLvJ9u3b1axZMwUEBMjPz0+NGjXS1q1bU5y3c+dO1a9fX97e3ipatKjGjBmjmTNnpvjdnNoaJR9++KHuvfde+fj4KF++fLr//vv15ZdfSkr63T9gwABJUokSJey5ltxmamuUXLx4Ua+++qo9g4oWLaqOHTvq7Nmz6b7Xm69z/tte37597dcbpUuX1tixY2W1Ot7G9Ny5c3ruuecUEBCgoKAgderUSTt27Eix3ld6uWO1WvXee+/p3nvvlZeXlwoVKqQXX3xRFy5ccHitbdu2KTIyUgUKFJC3t7dKlCihLl26OJzz9ddfq3r16vL391dAQIAqV66s999/3348rTVK5s2bp+rVq8vb21sFChTQs88+m2LqdfJ7OH78uFq2bCk/Pz8FBwerf//+SkxMTPdzBuAcRpQAOUTyRUm+fPkkSTExMfr444/Vrl07devWTZcvX9Ynn3yiyMhI/fLLLymmfXz55Ze6fPmyXnzxRVksFo0bN05PPvmkDh48aP9H+4oVK9S6dWtVrFhRUVFROnfunJ5//nkVLVrUoS3DMPT4449rzZo1euGFF1S1alUtX75cAwYM0PHjxzVp0iSH8zds2KDFixfr5ZdfliRFRUWpefPmev311zV16lT17NlTFy5c0Lhx49SlSxf99NNPKd7/zRdYefPmVWBgoKSkC7mRI0eqcePG6tGjh/bu3atp06bp119/1aZNmxyKEufOnVOzZs30zDPP6Nlnn1WhQoVktVr1+OOPa+PGjerevbsqVKigP/74Q5MmTdLff/+thQsXSpJ2796t5s2bq0qVKho1apQ8PT21f/9+bdq0SZJUoUIFjRo1SsOGDVP37t3tF30PPfSQ9u3bp7/++ktdunSRv7+/U3/nN1u5cqUOHjyo559/XiEhIdq9e7f+7//+T7t379bWrVtlsVgkSS+99JLmz5+vXr16qWLFijp37pw2btyoPXv2qFq1aoqLi1NkZKRu3LihV155RSEhITp+/LiWLFmiixcv2j/X/3ryyScVFBSkV1991T5c2s/PL82+fv755+rUqZMiIyM1duxYXb16VdOmTVOdOnW0fft2h8WBExISFBkZqTp16mj8+PF3bIQPMo5bEAJ3zu7du1W3bl0FBATo9ddfV968eTVjxgw1aNBA69atU82aNSVJx48fV8OGDWWxWDR48GD5+vrq448/dmq0x0cffaTevXvrqaeeUp8+fXT9+nXt3LlTP//8s9q3b68nn3xSf//9t7766itNmjRJBQoUkCQFBwen2l5sbKzq1q2rPXv2qEuXLqpWrZrOnj2rxYsX69ixY/bnp+bm6xwpadRn/fr1dfz4cb344osqVqyYNm/erMGDB+vkyZN67733JCUVOFq0aKFffvlFPXr0UPny5bVo0SJ16tQp1ddKK3defPFFzZo1S88//7x69+6tQ4cOafLkydq+fbv9euLff//VI488ouDgYA0aNEhBQUE6fPiwwzpjK1euVLt27dSoUSONHTtWkrRnzx5t2rRJffr0SfMzSH7tBx54QFFRUTp9+rTef/99bdq0Sdu3b3cY0ZmYmKjIyEjVrFlT48eP16pVqzRhwgSVKlVKPXr0SPM1kPuQzSYxAGQrM2fONCQZq1atMs6cOWMcPXrUmD9/vhEcHGx4enoaR48eNQzDMBISEowbN244PPfChQtGoUKFjC5dutj3HTp0yJBk3HPPPcb58+ft+xctWmRIMr7//nv7vqpVqxqFCxc2Ll68aN+3YsUKQ5IRHh5u37dw4UJDkjFmzBiH13/qqac
"text/plain": [
"<Figure size 1200x1000 with 7 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"\n",
"_, ax = plt.subplots(ceil(len(models_classification) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"\n",
"for index, key in enumerate(models_classification.keys()):\n",
" c_matrix = results[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Sale\", \"Option Exercise\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие параметры: {'model__criterion': 'gini', 'model__max_depth': 5, 'model__max_features': 'sqrt', 'model__n_estimators': 10}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"e:\\aim\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
}
],
"source": [
"# Создание конвейера\n",
"pipeline = Pipeline([\n",
" (\"processing\", pipeline_end),\n",
" (\"model\", RandomForestClassifier(random_state=42))\n",
"])\n",
"\n",
"# Установка параметров для поиска по сетке\n",
"param_grid: dict[str, Any] = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью поиска по сетке\n",
"grid_search = GridSearchCV(estimator=pipeline, \n",
" param_grid=param_grid,\n",
" n_jobs=-1)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_df_train, y_df_train)\n",
"\n",
"# Результаты подбора гиперпараметров\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сравнение наборов гиперпараметров:\n",
"\n",
"Результаты анализа показывают, что как стоковая модель, так и оптимизированная модель демонстрируют идентичные показатели качества, включая абсолютные значения всех ключевых метрик (Precision, Recall, Accuracy, F1-Score и другие), равные 1.0 на обеих выборках (обучающей и тестовой). Это указывает на то, что обе модели идеально справляются с задачей классификации."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Стоковая модель:\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[13 0]\n",
" [ 0 19]]\n",
"\n",
"Оптимизированная модель:\n",
"\tPrecision_train: 1.0\n",
"\tPrecision_test: 1.0\n",
"\tRecall_train: 1.0\n",
"\tRecall_test: 1.0\n",
"\tAccuracy_train: 1.0\n",
"\tAccuracy_test: 1.0\n",
"\tF1_train: 1.0\n",
"\tF1_test: 1.0\n",
"\tROC_AUC_test: 1.0\n",
"\tCohen_kappa_test: 1.0\n",
"\tMCC_test: 1.0\n",
"\tConfusion_matrix: [[13 0]\n",
" [ 0 19]]\n"
]
}
],
"source": [
"# Обучение модели со старыми гипермараметрами\n",
"pipeline.fit(X_df_train, y_df_train)\n",
"\n",
"# Предсказание для обучающей и тестовой выборки\n",
"y_train_predict = pipeline.predict(X_df_train)\n",
"y_test_predict = pipeline.predict(X_df_test)\n",
" \n",
"# Вычисление метрик для модели со старыми гипермараметрами\n",
"base_model_metrics: dict[str, Any] = {\n",
" \"Precision_train\": metrics.precision_score(y_df_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_df_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_df_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_df_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_df_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_df_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_df_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_df_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_df_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_df_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_df_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_df_test, y_test_predict),\n",
"}\n",
"\n",
"# Модель с новыми гипермараметрами\n",
"optimized_model = RandomForestClassifier(\n",
" random_state=42,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"sqrt\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"# Создание конвейера для модели с новыми гипермараметрами\n",
"optimized_model_pipeline = Pipeline(\n",
" [\n",
" (\"pipeline\", pipeline_end), \n",
" (\"model\", optimized_model),\n",
" ]\n",
")\n",
" \n",
"# Обучение модели с новыми гипермараметрами\n",
"optimized_model_pipeline.fit(X_df_train, y_df_train)\n",
" \n",
"# Предсказание для обучающей и тестовой выборки\n",
"y_train_predict = optimized_model_pipeline.predict(X_df_train)\n",
"y_test_predict = optimized_model_pipeline.predict(X_df_test)\n",
" \n",
"# Вычисление метрик для модели с новыми гипермараметрами\n",
"optimized_model_metrics: dict[str, Any] = {\n",
" \"Precision_train\": metrics.precision_score(y_df_train, y_train_predict),\n",
" \"Precision_test\": metrics.precision_score(y_df_test, y_test_predict),\n",
" \"Recall_train\": metrics.recall_score(y_df_train, y_train_predict),\n",
" \"Recall_test\": metrics.recall_score(y_df_test, y_test_predict),\n",
" \"Accuracy_train\": metrics.accuracy_score(y_df_train, y_train_predict),\n",
" \"Accuracy_test\": metrics.accuracy_score(y_df_test, y_test_predict),\n",
" \"F1_train\": metrics.f1_score(y_df_train, y_train_predict),\n",
" \"F1_test\": metrics.f1_score(y_df_test, y_test_predict),\n",
" \"ROC_AUC_test\": metrics.roc_auc_score(y_df_test, y_test_predict),\n",
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_df_test, y_test_predict),\n",
" \"MCC_test\": metrics.matthews_corrcoef(y_df_test, y_test_predict),\n",
" \"Confusion_matrix\": metrics.confusion_matrix(y_df_test, y_test_predict),\n",
"}\n",
"\n",
"# Вывод информации\n",
"print('Стоковая модель:')\n",
"for metric_name, value in base_model_metrics.items():\n",
" print(f\"\\t{metric_name}: {value}\")\n",
"\n",
"print('\\nОптимизированная модель:')\n",
"for metric_name, value in optimized_model_metrics.items():\n",
" print(f\"\\t{metric_name}: {value}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}