{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset: [Tesla Insider Trading](https://www.kaggle.com/datasets/ilyaryabov/tesla-insider-trading).\n",
"\n",
"### Dataset description:\n",
"The dataset is a sample of transactions in Tesla securities made by company insiders and is part of the larger project \"Insider Trading S&P500 – Inside Info\". It covers transactions by major shareholders and company officers, including purchases, sales and option exercises, from November 10, 2021 to July 27, 2022.\n",
"\n",
"---\n",
"\n",
"### Analysis of the data:\n",
"**Problem domain:**\n",
"The dataset belongs to the analysis of insider trading in public companies and its effect on stock pricing. Insider transactions, made by people with access to non-public information (such as executives, major shareholders or board members), can be indicators of future changes in the stock price. Studying them helps show how information inside a company is reflected in the actions of its key participants and can reveal behavioral patterns that move markets.\n",
"\n",
"**Relevance:**\n",
"Analyzing insider trades is especially important when the market is volatile and uncertain. Investors, analysts and companies use such data to better understand the signals sent by major shareholders and officers. Insider purchases and sales are often treated as indicators of confidence in the company and can significantly influence market expectations and forecasts.\n",
"\n",
"**Objects of observation:**\n",
"The objects of observation are Tesla insiders: people with significant influence over the company's management and access to its information. Each row of the dataset is one insider transaction, described by parameters such as the insider's position, the transaction type, the number of shares and the total transaction value.\n",
"\n",
"**Object attributes:**\n",
"- Insider Trading: full name of the person who made the transaction.\n",
"- Relationship: the person's position or role at Tesla.\n",
"- Date: the date the transaction was completed.\n",
"- Transaction: the transaction type.\n",
"- Cost: price per share at the time of the transaction.\n",
"- Shares: number of shares involved in the transaction.\n",
"- Value ($): total value of the transaction in US dollars.\n",
"- Shares Total: total number of shares the person holds after the transaction.\n",
"- SEC Form 4: the date the transaction was reported on SEC Form 4, the mandatory filing for insider trades.\n",
"\n",
"---\n",
"\n",
"### Business goals:\n",
"1. **Regression task:**\n",
"Predict the Tesla share price (\"Cost\") from insider transaction data. The share price depends on many factors, including the volume and type of insider transactions. If relationships can be identified between transaction parameters (number of shares, total transaction value, the insider's position) and the share price, this could help investors make informed decisions about buying or selling.\n",
"2. **Classification task:**\n",
"Classify the type of an insider transaction (share sale or option exercise) from the characteristics of the deal. The transaction type (\"Transaction\") can indicate the insider's confidence in the current market price or in the company's future profitability. A model that predicts the transaction type can help assess insider behavior and detect anomalies.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading data from the file into a DataFrame:"
]
},
{
"cell_type": "code",
"execution_count": 379,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Tuple\n",
"from math import ceil\n",
"\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"# Load the Tesla insider trading CSV into a DataFrame\n",
"df: DataFrame = pd.read_csv('../static/csv/TSLA.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### DataFrame overview:"
]
},
{
"cell_type": "code",
"execution_count": 380,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 156 entries, 0 to 155\n",
"Data columns (total 9 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Insider Trading 156 non-null object \n",
" 1 Relationship 156 non-null object \n",
" 2 Date 156 non-null object \n",
" 3 Transaction 156 non-null object \n",
" 4 Cost 156 non-null float64\n",
" 5 Shares 156 non-null object \n",
" 6 Value ($) 156 non-null object \n",
" 7 Shares Total 156 non-null object \n",
" 8 SEC Form 4 156 non-null object \n",
"dtypes: float64(1), object(8)\n",
"memory usage: 11.1+ KB\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>count</th>\n",
" <th>mean</th>\n",
" <th>std</th>\n",
" <th>min</th>\n",
" <th>25%</th>\n",
" <th>50%</th>\n",
" <th>75%</th>\n",
" <th>max</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Cost</th>\n",
" <td>156.0</td>\n",
" <td>478.785641</td>\n",
" <td>448.922903</td>\n",
" <td>0.0</td>\n",
" <td>50.5225</td>\n",
" <td>240.225</td>\n",
" <td>934.1075</td>\n",
" <td>1171.04</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" count mean std min 25% 50% 75% max\n",
"Cost 156.0 478.785641 448.922903 0.0 50.5225 240.225 934.1075 1171.04"
]
},
"execution_count": 380,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Summary of the DataFrame: columns, non-null counts and dtypes\n",
"df.info()\n",
"\n",
"# Descriptive statistics of the numeric columns (only Cost is numeric at this point)\n",
"df.describe().transpose()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data preprocessing:"
]
},
{
"cell_type": "code",
"execution_count": 381,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Data sample:\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Insider Trading</th>\n",
" <th>Relationship</th>\n",
" <th>Transaction</th>\n",
" <th>Cost</th>\n",
" <th>Shares</th>\n",
" <th>Value ($)</th>\n",
" <th>Shares Total</th>\n",
" <th>Year</th>\n",
" <th>Month</th>\n",
" <th>Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Kirkhorn Zachary</td>\n",
" <td>Chief Financial Officer</td>\n",
" <td>Sale</td>\n",
" <td>196.72</td>\n",
" <td>10455</td>\n",
" <td>2056775</td>\n",
" <td>203073</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Taneja Vaibhav</td>\n",
" <td>Chief Accounting Officer</td>\n",
" <td>Sale</td>\n",
" <td>195.79</td>\n",
" <td>2466</td>\n",
" <td>482718</td>\n",
" <td>100458</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Sale</td>\n",
" <td>195.79</td>\n",
" <td>1298</td>\n",
" <td>254232</td>\n",
" <td>65547</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Taneja Vaibhav</td>\n",
" <td>Chief Accounting Officer</td>\n",
" <td>Option Exercise</td>\n",
" <td>0.00</td>\n",
" <td>7138</td>\n",
" <td>0</td>\n",
" <td>102923</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Option Exercise</td>\n",
" <td>0.00</td>\n",
" <td>2586</td>\n",
" <td>0</td>\n",
" <td>66845</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Kirkhorn Zachary</td>\n",
" <td>Chief Financial Officer</td>\n",
" <td>Option Exercise</td>\n",
" <td>0.00</td>\n",
" <td>16867</td>\n",
" <td>0</td>\n",
" <td>213528</td>\n",
" <td>2022</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Option Exercise</td>\n",
" <td>20.91</td>\n",
" <td>10500</td>\n",
" <td>219555</td>\n",
" <td>74759</td>\n",
" <td>2022</td>\n",
" <td>2</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Sale</td>\n",
" <td>202.00</td>\n",
" <td>10500</td>\n",
" <td>2121000</td>\n",
" <td>64259</td>\n",
" <td>2022</td>\n",
" <td>2</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>Kirkhorn Zachary</td>\n",
" <td>Chief Financial Officer</td>\n",
" <td>Sale</td>\n",
" <td>193.00</td>\n",
" <td>3750</td>\n",
" <td>723750</td>\n",
" <td>196661</td>\n",
" <td>2022</td>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>Baglino Andrew D</td>\n",
" <td>SVP Powertrain and Energy Eng.</td>\n",
" <td>Option Exercise</td>\n",
" <td>20.91</td>\n",
" <td>10500</td>\n",
" <td>219555</td>\n",
" <td>74759</td>\n",
" <td>2022</td>\n",
" <td>1</td>\n",
" <td>27</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Insider Trading Relationship Transaction Cost \\\n",
"0 Kirkhorn Zachary Chief Financial Officer Sale 196.72 \n",
"1 Taneja Vaibhav Chief Accounting Officer Sale 195.79 \n",
"2 Baglino Andrew D SVP Powertrain and Energy Eng. Sale 195.79 \n",
"3 Taneja Vaibhav Chief Accounting Officer Option Exercise 0.00 \n",
"4 Baglino Andrew D SVP Powertrain and Energy Eng. Option Exercise 0.00 \n",
"5 Kirkhorn Zachary Chief Financial Officer Option Exercise 0.00 \n",
"6 Baglino Andrew D SVP Powertrain and Energy Eng. Option Exercise 20.91 \n",
"7 Baglino Andrew D SVP Powertrain and Energy Eng. Sale 202.00 \n",
"8 Kirkhorn Zachary Chief Financial Officer Sale 193.00 \n",
"9 Baglino Andrew D SVP Powertrain and Energy Eng. Option Exercise 20.91 \n",
"\n",
" Shares Value ($) Shares Total Year Month Day \n",
"0 10455 2056775 203073 2022 3 6 \n",
"1 2466 482718 100458 2022 3 6 \n",
"2 1298 254232 65547 2022 3 6 \n",
"3 7138 0 102923 2022 3 5 \n",
"4 2586 0 66845 2022 3 5 \n",
"5 16867 0 213528 2022 3 5 \n",
"6 10500 219555 74759 2022 2 27 \n",
"7 10500 2121000 64259 2022 2 27 \n",
"8 3750 723750 196661 2022 2 6 \n",
"9 10500 219555 74759 2022 1 27 "
]
},
"execution_count": 381,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Convert column types\n",
"df['Insider Trading'] = df['Insider Trading'].astype('category') # convert to category\n",
"df['Relationship'] = df['Relationship'].astype('category') # convert to category\n",
"df['Transaction'] = df['Transaction'].astype('category') # convert to category\n",
"df['Cost'] = pd.to_numeric(df['Cost'], errors='coerce') # ensure numeric (already float64)\n",
"df['Shares'] = pd.to_numeric(df['Shares'].str.replace(',', ''), errors='coerce') # strip thousands separators, convert to number\n",
"df['Value ($)'] = pd.to_numeric(df['Value ($)'].str.replace(',', ''), errors='coerce') # strip thousands separators, convert to number\n",
"df['Shares Total'] = pd.to_numeric(df['Shares Total'].str.replace(',', ''), errors='coerce') # strip thousands separators, convert to number\n",
"\n",
"df['Date'] = pd.to_datetime(df['Date'], errors='coerce') # convert to datetime\n",
"df['Year'] = df['Date'].dt.year # year\n",
"df['Month'] = df['Date'].dt.month # month\n",
"df['Day'] = df['Date'].dt.day # day\n",
"df: DataFrame = df.drop(columns=['Date', 'SEC Form 4']) # drop the original date columns (replaced by Year/Month/Day)\n",
"\n",
"print('Data sample:')\n",
"df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Business goal #1 (regression task).\n",
"\n",
"### Achievable model quality:\n",
"**Main regression metrics:**\n",
"- **Mean Absolute Error (MAE)**: the average absolute difference between predicted and actual values. It is easy to interpret, especially for financial data, because it is expressed in the same units as the target (dollars per share for Cost).\n",
"- **Mean Squared Error (MSE)**: the average of the squared differences between predicted and actual values. Squaring penalizes large errors more heavily than small ones; the result is in squared units of the target.\n",
"- **Coefficient of determination (R²)**: the proportion of the target's variance explained by the model. R² is at most 1 (a perfect fit); 0 corresponds to always predicting the mean of the target, and on test data R² can be negative if the model does worse than that.\n",
"\n",
"---\n",
"\n",
"### Choosing a baseline:\n",
"The baseline model predicts the mean of the target variable (Cost) on the training set for every sample. It is a simple, intuitive method that sets the minimum reference point for more complex models: it fixes the initial error level (MAE, MSE) and quality score (R²) that those models must improve on to justify their use.\n",
"\n",
"---"
]
},
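{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the standard definitions of the three metrics are given below, where $y_i$ is the actual value, $\\hat{y}_i$ the prediction and $\\bar{y}$ the mean of the actual values over the $n$ evaluated samples. These are the usual textbook formulas; sklearn's `r2_score` computes $\\bar{y}$ from the true values passed to it.\n",
"\n",
"$$\\mathrm{MAE} = \\frac{1}{n}\\sum_{i=1}^{n} \\lvert y_i - \\hat{y}_i \\rvert, \\qquad \\mathrm{MSE} = \\frac{1}{n}\\sum_{i=1}^{n} (y_i - \\hat{y}_i)^2,$$\n",
"\n",
"$$R^2 = 1 - \\frac{\\sum_{i=1}^{n} (y_i - \\hat{y}_i)^2}{\\sum_{i=1}^{n} (y_i - \\bar{y})^2}.$$"
]
},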
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data split:"
]
},
{
"cell_type": "code",
"execution_count": 382,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline MAE: 417.78235887096776\n",
"Baseline MSE: 182476.07973024843\n",
"Baseline R²: -0.027074997920953914\n"
]
}
],
"source": [
"from typing import Tuple\n",
"\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
"\n",
"\n",
"# Split the data into training and test sets\n",
"def split_into_train_test(\n",
"    df_input: DataFrame,\n",
"    stratify_colname: str = \"y\",\n",
"    frac_train: float = 0.8,\n",
"    random_state: int = 42,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    # Despite its name, stratify_colname only selects the target column:\n",
"    # train_test_split is called without stratify, so the split is purely random.\n",
"    if stratify_colname not in df_input.columns:\n",
"        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
"\n",
"    if not (0 < frac_train < 1):\n",
"        raise ValueError(\"Fraction must be between 0 and 1.\")\n",
"\n",
"    # X keeps all columns, including the target itself; drop the target\n",
"    # before fitting a model, otherwise the model sees the answer (leakage).\n",
"    X: DataFrame = df_input\n",
"    # DataFrame containing only the target column.\n",
"    y: DataFrame = df_input[[stratify_colname]]\n",
"\n",
"    # Split the original DataFrame into train and test DataFrames.\n",
"    X_train, X_test, y_train, y_test = train_test_split(\n",
"        X, y,\n",
"        test_size=(1.0 - frac_train),\n",
"        random_state=random_state,\n",
"    )\n",
"\n",
"    return X_train, X_test, y_train, y_test\n",
"\n",
"\n",
"# Define the target and the input features\n",
"y_feature: str = 'Cost'\n",
"X_features: list[str] = df.drop(columns=y_feature).columns.tolist()\n",
"\n",
"# Split the data into training and test sets\n",
"X_df_train, X_df_test, y_df_train, y_df_test = split_into_train_test(\n",
"    df,\n",
"    stratify_colname=y_feature,\n",
"    frac_train=0.8,\n",
"    random_state=42,\n",
")\n",
"\n",
"# Baseline predictions: the training-set mean of the target for every test sample\n",
"baseline_predictions: list[float] = [float(y_df_train[y_feature].mean())] * len(y_df_test)\n",
"\n",
"# Evaluate the baseline\n",
"print('Baseline MAE:', mean_absolute_error(y_df_test, baseline_predictions))\n",
"print('Baseline MSE:', mean_squared_error(y_df_test, baseline_predictions))\n",
"print('Baseline R²:', r2_score(y_df_test, baseline_predictions))"
]
},
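{
"cell_type": "markdown",
"metadata": {},
"source": [
"The slightly negative baseline R² printed above follows directly from the definition of $R^2$. Suppose the baseline predicts the same constant $c$ (here, the training-set mean of Cost) for all $n$ test samples, and let $\\bar{y}$ be the mean of the test targets. Expanding the squared error around $\\bar{y}$ (the cross term vanishes because $\\sum_i (y_i - \\bar{y}) = 0$):\n",
"\n",
"$$\\sum_{i=1}^{n} (y_i - c)^2 = \\sum_{i=1}^{n} (y_i - \\bar{y})^2 + n(\\bar{y} - c)^2,$$\n",
"\n",
"so\n",
"\n",
"$$R^2 = 1 - \\frac{\\sum_{i=1}^{n} (y_i - c)^2}{\\sum_{i=1}^{n} (y_i - \\bar{y})^2} = -\\frac{n(\\bar{y} - c)^2}{\\sum_{i=1}^{n} (y_i - \\bar{y})^2} \\le 0.$$\n",
"\n",
"A constant predictor therefore scores at most $R^2 = 0$, reached only when $c$ equals the test mean. A trained model should score clearly above 0 on this split to be worth using."
]
},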
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Choice of models:\n",
"\n",
"The following models were chosen for training:\n",
"1. **Random Forest**: an ensemble model built from many decision trees. It handles nonlinear dependencies and noise in the data well and, because it averages many trees, is relatively resistant to overfitting.\n",
"2. **Linear Regression**: a simple model that assumes a linear relationship between the features and the target. It trains quickly and its results are easy to interpret.\n",
"3. **Gradient Boosting**: an ensemble of trees built sequentially, each one correcting the errors of the previous ones. It works well on complex datasets and often achieves high predictive accuracy.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Building the pipeline:"
]
},
{
"cell_type": "code",
"execution_count": 383,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.impute import SimpleImputer\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"\n",
"# Numeric columns: everything that is not category/object.\n",
"# Computed from the full df, so the target Cost is included as well.\n",
"num_columns: list[str] = [\n",
"    column\n",
"    for column in df.columns\n",
"    if df[column].dtype not in (\"category\", \"object\")\n",
"]\n",
"\n",
"# Categorical columns\n",
"cat_columns: list[str] = [\n",
"    column\n",
"    for column in df.columns\n",
"    if df[column].dtype in (\"category\", \"object\")\n",
"]\n",
"\n",
"# Fill missing numeric values with the median\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"# Standardize to zero mean and unit variance\n",
"num_scaler = StandardScaler()\n",
"# Pipeline for numeric features\n",
"preprocessing_num = Pipeline(\n",
"    [\n",
"        (\"imputer\", num_imputer),\n",
"        (\"scaler\", num_scaler),\n",
"    ]\n",
")\n",
"\n",
"# Fill missing categorical values with the constant \"unknown\"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"# One-hot encoding (the first category of each feature is dropped)\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"# Pipeline for categorical features\n",
"preprocessing_cat = Pipeline(\n",
"    [\n",
"        (\"imputer\", cat_imputer),\n",
"        (\"encoder\", cat_encoder),\n",
"    ]\n",
")\n",
"\n",
"# Column transformer that applies each pipeline to its group of columns\n",
"features_preprocessing = ColumnTransformer(\n",
"    verbose_feature_names_out=False,\n",
"    transformers=[\n",
"        (\"prepocessing_num\", preprocessing_num, num_columns),\n",
"        (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
"    ],\n",
"    remainder=\"passthrough\",\n",
")\n",
"\n",
"# Top-level data preprocessing pipeline\n",
"pipeline_end = Pipeline(\n",
"    [\n",
"        (\"features_preprocessing\", features_preprocessing),\n",
"    ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pipeline demonstration:"
]
},
{
"cell_type": "code",
"execution_count": 384,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Cost</th>\n",
" <th>Shares</th>\n",
" <th>Value ($)</th>\n",
" <th>Shares Total</th>\n",
" <th>Year</th>\n",
" <th>Month</th>\n",
" <th>Day</th>\n",
" <th>Insider Trading_DENHOLM ROBYN M</th>\n",
" <th>Insider Trading_Kirkhorn Zachary</th>\n",
" <th>Insider Trading_Musk Elon</th>\n",
" <th>Insider Trading_Musk Kimbal</th>\n",
" <th>Insider Trading_Taneja Vaibhav</th>\n",
" <th>Insider Trading_Wilson-Thompson Kathleen</th>\n",
" <th>Relationship_Chief Accounting Officer</th>\n",
" <th>Relationship_Chief Financial Officer</th>\n",
" <th>Relationship_Director</th>\n",
" <th>Relationship_SVP Powertrain and Energy Eng.</th>\n",
" <th>Transaction_Sale</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.966516</td>\n",
" <td>-0.361759</td>\n",
" <td>-0.450022</td>\n",
" <td>-0.343599</td>\n",
" <td>0.715678</td>\n",
" <td>-0.506108</td>\n",
" <td>-0.400623</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>-1.074894</td>\n",
" <td>1.225216</td>\n",
" <td>-0.414725</td>\n",
" <td>-0.319938</td>\n",
" <td>-1.397276</td>\n",
" <td>0.801338</td>\n",
" <td>0.906673</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-1.074894</td>\n",
" <td>1.211753</td>\n",
" <td>-0.415027</td>\n",
" <td>-0.320141</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.098939</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.167142</td>\n",
" <td>0.037499</td>\n",
" <td>1.023612</td>\n",
" <td>-0.325853</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.501184</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.217886</td>\n",
" <td>-0.075287</td>\n",
" <td>0.632973</td>\n",
" <td>-0.330205</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.501184</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.505872</td>\n",
" <td>-0.361021</td>\n",
" <td>-0.443679</td>\n",
" <td>-0.343698</td>\n",
" <td>0.715678</td>\n",
" <td>-0.767598</td>\n",
" <td>1.308918</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>-1.088674</td>\n",
" <td>-0.357532</td>\n",
" <td>-0.450389</td>\n",
" <td>-0.342863</td>\n",
" <td>0.715678</td>\n",
" <td>0.278360</td>\n",
" <td>-0.903429</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>-0.692146</td>\n",
" <td>-0.355855</td>\n",
" <td>-0.445383</td>\n",
" <td>-0.343220</td>\n",
" <td>0.715678</td>\n",
" <td>0.801338</td>\n",
" <td>1.409480</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>-1.088674</td>\n",
" <td>-0.361181</td>\n",
" <td>-0.450389</td>\n",
" <td>-0.343649</td>\n",
" <td>-1.397276</td>\n",
" <td>1.062828</td>\n",
" <td>-0.903429</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1.091997</td>\n",
" <td>-0.204531</td>\n",
" <td>0.114712</td>\n",
" <td>1.538166</td>\n",
" <td>0.715678</td>\n",
" <td>-1.029087</td>\n",
" <td>1.208357</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Cost Shares Value ($) Shares Total Year Month Day \\\n",
"0 -0.966516 -0.361759 -0.450022 -0.343599 0.715678 -0.506108 -0.400623 \n",
"1 -1.074894 1.225216 -0.414725 -0.319938 -1.397276 0.801338 0.906673 \n",
"2 -1.074894 1.211753 -0.415027 -0.320141 -1.397276 1.062828 -0.098939 \n",
"3 1.167142 0.037499 1.023612 -0.325853 -1.397276 1.062828 -0.501184 \n",
"4 1.217886 -0.075287 0.632973 -0.330205 -1.397276 1.062828 -0.501184 \n",
"5 0.505872 -0.361021 -0.443679 -0.343698 0.715678 -0.767598 1.308918 \n",
"6 -1.088674 -0.357532 -0.450389 -0.342863 0.715678 0.278360 -0.903429 \n",
"7 -0.692146 -0.355855 -0.445383 -0.343220 0.715678 0.801338 1.409480 \n",
"8 -1.088674 -0.361181 -0.450389 -0.343649 -1.397276 1.062828 -0.903429 \n",
"9 1.091997 -0.204531 0.114712 1.538166 0.715678 -1.029087 1.208357 \n",
"\n",
" Insider Trading_DENHOLM ROBYN M Insider Trading_Kirkhorn Zachary \\\n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Insider Trading_Musk Elon Insider Trading_Musk Kimbal \\\n",
"0 0.0 0.0 \n",
"1 1.0 0.0 \n",
"2 1.0 0.0 \n",
"3 1.0 0.0 \n",
"4 1.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 1.0 0.0 \n",
"\n",
" Insider Trading_Taneja Vaibhav Insider Trading_Wilson-Thompson Kathleen \\\n",
"0 1.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 1.0 0.0 \n",
"7 0.0 0.0 \n",
"8 1.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Relationship_Chief Accounting Officer \\\n",
"0 1.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"5 0.0 \n",
"6 1.0 \n",
"7 0.0 \n",
"8 1.0 \n",
"9 0.0 \n",
"\n",
" Relationship_Chief Financial Officer Relationship_Director \\\n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"5 0.0 0.0 \n",
"6 0.0 0.0 \n",
"7 0.0 0.0 \n",
"8 0.0 0.0 \n",
"9 0.0 0.0 \n",
"\n",
" Relationship_SVP Powertrain and Energy Eng. Transaction_Sale \n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
|
|||
|
"3 0.0 1.0 \n",
|
|||
|
"4 0.0 1.0 \n",
|
|||
|
"5 1.0 1.0 \n",
|
|||
|
"6 0.0 0.0 \n",
|
|||
|
"7 1.0 1.0 \n",
|
|||
|
"8 0.0 0.0 \n",
|
|||
|
"9 0.0 1.0 "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 384,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
"source": [
"# Apply the preprocessing pipeline\n",
"preprocessing_result = pipeline_end.fit_transform(X_df_train)\n",
"preprocessed_df = pd.DataFrame(\n",
"    preprocessing_result,\n",
"    columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df.head(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model training:\n",
"\n",
"Cross-validation results (mean and standard deviation over 5 folds of R², the default score that `cross_val_score` uses for regressors):\n",
"1. **Random Forest**:\n",
"   - Metrics:\n",
"     - Mean score: 0.9993.\n",
"     - Standard deviation: 0.00048.\n",
"   - Conclusion: very high R² on every fold, which suggests good generalization. The low standard deviation indicates that the model is stable.\n",
"2. **Linear Regression**:\n",
"   - Metrics:\n",
"     - Mean score: 1.0.\n",
"     - Standard deviation: 0.0.\n",
"   - Conclusion: a perfect score on every held-out fold. This is not typical overfitting, since an overfit model loses quality on held-out data; it more likely means that the target is an exact linear function of the input features (a possible target leakage), which should be checked before trusting the model on new data.\n",
"3. **Gradient Boosting**:\n",
"   - Metrics:\n",
"     - Mean score: 0.9998.\n",
"     - Standard deviation: 0.00014.\n",
"   - Conclusion: excellent results with high R² and low variability. The model is also stable."
]
},
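{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch on synthetic data (not the Tesla dataset) of why a cross-validated R² of exactly 1.0 with zero spread, as obtained for Linear Regression, usually points to a target that is a linear function of the inputs rather than to overfitting. The variable names and the data-generating formula are hypothetical. With the exact linear target every fold scores R² = 1.0; once noise is added, the score drops below 1 as expected, so a perfect score on real data is a reason to check which feature columns determine the target."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"# Hypothetical data: the target is an exact linear combination of two features,\n",
"# mimicking a leaked target column\n",
"rng = np.random.default_rng(0)\n",
"X = rng.normal(size=(200, 3))\n",
"y_leaky = 3.0 * X[:, 0] - 2.0 * X[:, 1]\n",
"# The same target with added noise, for comparison\n",
"y_noisy = y_leaky + rng.normal(scale=1.0, size=200)\n",
"\n",
"# For regressors, cross_val_score uses R^2 by default\n",
"leaky_scores = cross_val_score(LinearRegression(), X, y_leaky, cv=5)\n",
"noisy_scores = cross_val_score(LinearRegression(), X, y_noisy, cv=5)\n",
"\n",
"print('leaky target: mean R^2 =', leaky_scores.mean(), 'std =', leaky_scores.std())\n",
"print('noisy target: mean R^2 =', noisy_scores.mean(), 'std =', noisy_scores.std())"
]
},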
{
"cell_type": "code",
"execution_count": 385,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: Random Forest\n",
"\tmean_score: 0.9992580181099008\n",
"\tstd_dev: 0.0004834744839371662\n",
"\n",
"Model: Linear Regression\n",
"\tmean_score: 1.0\n",
"\tstd_dev: 0.0\n",
"\n",
"Model: Gradient Boosting\n",
"\tmean_score: 0.9997687065029746\n",
"\tstd_dev: 0.00014193622424523165\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n",
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\sklearn\\ensemble\\_gb.py:668: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True) # TODO: Is this still required?\n"
]
}
],
"source": [
"import numpy as np\n",
"from typing import Any\n",
"\n",
"from pandas import DataFrame\n",
"from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"\n",
"# Train the models and evaluate them with cross-validation\n",
"def train_models(X: DataFrame, y: DataFrame, \n",
"                 models: dict[str, Any]) -> dict[str, dict[str, Any]]:\n",
"    results: dict[str, dict[str, Any]] = {}\n",
"    \n",
"    for model_name, model in models.items():\n",
"        # Build a pipeline for the current model\n",
"        model_pipeline = Pipeline(\n",
"            [\n",
"                (\"features_preprocessing\", features_preprocessing),\n",
"                (\"model\", model)\n",
"            ]\n",
"        )\n",
"        \n",
"        # Fit and score the model with 5-fold cross-validation;\n",
"        # y is flattened to 1-D to avoid DataConversionWarning\n",
"        scores = cross_val_score(model_pipeline, X, np.ravel(y), cv=5)\n",
"        \n",
"        # Compute the metrics for the current model (R² for regressors)\n",
"        metrics_dict: dict[str, Any] = {\n",
"            \"mean_score\": scores.mean(),\n",
"            \"std_dev\": scores.std()\n",
"        }\n",
"        \n",
"        # Save the results\n",
"        results[model_name] = metrics_dict\n",
"    \n",
"    return results\n",
"\n",
"\n",
"# Regression models to compare\n",
"models_regression: dict[str, Any] = {\n",
"    \"Random Forest\": RandomForestRegressor(),\n",
"    \"Linear Regression\": LinearRegression(),\n",
"    \"Gradient Boosting\": GradientBoostingRegressor(),\n",
"}\n",
"\n",
"results: dict[str, Any] = train_models(X_df_train, y_df_train, models_regression)\n",
"\n",
"# Print the results\n",
"for model_name, metrics_dict in results.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    for metric_name, value in metrics_dict.items():\n",
"        print(f\"\\t{metric_name}: {value}\")\n",
"    print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluation on the test set:\n",
"\n",
"Evaluation of the training results:\n",
"1. Random Forest:\n",
"   - Metrics:\n",
"     - MAE (train): 1.956\n",
"     - MAE (test): 4.465\n",
"     - MSE (train): 11.288\n",
"     - MSE (test): 66.471\n",
"     - R² (train): 0.9999\n",
"     - R² (test): 0.9996\n",
"     - STD of residuals (train): 3.352\n",
"     - STD of residuals (test): 8.068\n",
"   - Conclusion: Random Forest shows excellent R² values on both the training and test sets, which indicates strong generalization. However, MAE and MSE on the test set are considerably higher than on the training set (the test MSE is about six times the training MSE), which may indicate some overfitting.\n",
"2. Linear Regression:\n",
"   - Metrics:\n",
"     - MAE (train): 3.069e-13\n",
"     - MAE (test): 2.762e-13\n",
"     - MSE (train): 1.437e-25\n",
"     - MSE (test): 1.196e-25\n",
"     - R² (train): 1.0\n",
"     - R² (test): 1.0\n",
"     - STD of residuals (train): 3.730e-13\n",
"     - STD of residuals (test): 3.444e-13\n",
"   - Conclusion: errors of order 1e-13 (MAE) and 1e-25 (MSE) are floating-point round-off, i.e. effectively zero, on both the training and the test sets. Because the fit is just as perfect on unseen data, this is not overfitting; as in cross-validation, it suggests that the target is an exact linear function of the input features (a possible target leakage).\n",
"3. Gradient Boosting:\n",
"   - Metrics:\n",
"     - MAE (train): 0.156\n",
"     - MAE (test): 2.976\n",
"     - MSE (train): 0.075\n",
"     - MSE (test): 38.917\n",
"     - R² (train): 0.9999996\n",
"     - R² (test): 0.9998\n",
"     - STD of residuals (train): 0.274\n",
"     - STD of residuals (test): 6.197\n",
"   - Conclusion: Gradient Boosting fits the training set almost perfectly, but MAE and MSE on the test set are much higher, which may indicate some overfitting or a need for better hyperparameter tuning. Its test errors are nevertheless lower than those of Random Forest."
]
},
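{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small worked example, with made-up numbers, of the four quantities reported in the test-set evaluation. STD here is the standard deviation of the residuals, computed as `np.std(y - y_pred)` in the evaluation cell; because `np.std` subtracts the mean residual, it measures the spread of the errors and not their bias, so it can be smaller than the root of the MSE (here sqrt(5) versus sqrt(6))."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"# Made-up true values and predictions\n",
"y_true = np.array([10.0, 20.0, 30.0, 40.0])\n",
"y_pred = np.array([12.0, 18.0, 30.0, 44.0])\n",
"\n",
"residuals = y_true - y_pred  # [-2, 2, 0, -4]\n",
"mae = metrics.mean_absolute_error(y_true, y_pred)  # (2 + 2 + 0 + 4) / 4 = 2.0\n",
"mse = metrics.mean_squared_error(y_true, y_pred)  # (4 + 4 + 0 + 16) / 4 = 6.0\n",
"r2 = metrics.r2_score(y_true, y_pred)  # 1 - 24 / 500 = 0.952\n",
"std = np.std(residuals)  # spread around the mean residual (-1): sqrt(5)\n",
"\n",
"print('MAE:', mae, 'MSE:', mse, 'R2:', r2, 'STD:', std)"
]
},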
{
"cell_type": "code",
"execution_count": 386,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: Random Forest\n",
"\tMAE_train: 1.955516935483828\n",
"\tMAE_test: 4.46537187499996\n",
"\tMSE_train: 11.287871282983637\n",
"\tMSE_test: 66.47081479843644\n",
"\tR2_train: 0.9999449583585838\n",
"\tR2_test: 0.9996258659651619\n",
"\tSTD_train: 3.351830348079478\n",
"\tSTD_test: 8.067958792345765\n",
"\n",
"Model: Linear Regression\n",
"\tMAE_train: 3.0690862038154006e-13\n",
"\tMAE_test: 2.761679773755077e-13\n",
"\tMSE_train: 1.4370485712253764e-25\n",
"\tMSE_test: 1.19585889812782e-25\n",
"\tR2_train: 1.0\n",
"\tR2_test: 1.0\n",
"\tSTD_train: 3.7295840825107354e-13\n",
"\tSTD_test: 3.4438670391637766e-13\n",
"\n",
"Model: Gradient Boosting\n",
"\tMAE_train: 0.15613772760448247\n",
"\tMAE_test: 2.9760510050502877\n",
"\tMSE_train: 0.07499640211231862\n",
"\tMSE_test: 38.91708171007616\n",
"\tR2_train: 0.9999996343043813\n",
"\tR2_test: 0.9997809534176997\n",
"\tSTD_train: 0.2738547098596601\n",
"\tSTD_test: 6.197132274535746\n",
"\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"from sklearn import metrics\n",
"\n",
"\n",
"# Evaluate each model with regression metrics on the training and test sets\n",
"def evaluate_models(models: dict[str, Any], \n",
"                    pipeline_end: Pipeline, \n",
"                    X_train: DataFrame, y_train, \n",
"                    X_test: DataFrame, y_test) -> dict[str, dict[str, Any]]:\n",
"    results: dict[str, dict[str, Any]] = {}\n",
"    \n",
"    for model_name, model in models.items():\n",
"        # Build a pipeline for the current model\n",
"        model_pipeline = Pipeline(\n",
"            [\n",
"                (\"pipeline\", pipeline_end), \n",
"                (\"model\", model),\n",
"            ]\n",
"        )\n",
"        \n",
"        # Fit the current model\n",
"        model_pipeline.fit(X_train, y_train)\n",
"\n",
"        # Predict on the training and test sets\n",
"        y_train_predict = model_pipeline.predict(X_train)\n",
"        y_test_predict = model_pipeline.predict(X_test)\n",
"\n",
"        # Compute the metrics for the current model;\n",
"        # STD is the standard deviation of the residuals\n",
"        metrics_dict: dict[str, Any] = {\n",
"            \"MAE_train\": metrics.mean_absolute_error(y_train, y_train_predict),\n",
"            \"MAE_test\": metrics.mean_absolute_error(y_test, y_test_predict),\n",
"            \"MSE_train\": metrics.mean_squared_error(y_train, y_train_predict),\n",
"            \"MSE_test\": metrics.mean_squared_error(y_test, y_test_predict),\n",
"            \"R2_train\": metrics.r2_score(y_train, y_train_predict),\n",
"            \"R2_test\": metrics.r2_score(y_test, y_test_predict),\n",
"            \"STD_train\": np.std(y_train - y_train_predict),\n",
"            \"STD_test\": np.std(y_test - y_test_predict),\n",
"        }\n",
"\n",
"        # Save the results\n",
"        results[model_name] = metrics_dict\n",
"    \n",
"    return results\n",
"\n",
"\n",
"# Flatten the targets to 1-D arrays\n",
"y_train = np.ravel(y_df_train)\n",
"y_test = np.ravel(y_df_test)\n",
"\n",
"results: dict[str, dict[str, Any]] = evaluate_models(models_regression,\n",
"                                                     pipeline_end,\n",
"                                                     X_df_train, y_train,\n",
"                                                     X_df_test, y_test)\n",
"\n",
"# Print the results\n",
"for model_name, metrics_dict in results.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    for metric_name, value in metrics_dict.items():\n",
"        print(f\"\\t{metric_name}: {value}\")\n",
"    print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Hyperparameter tuning:"
]
},
{
"cell_type": "code",
"execution_count": 387,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
"Best parameters: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 50}\n",
"Best result (MSE): 196.9489804872991\n"
]
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"\n",
"# Apply the preprocessing pipeline to the data\n",
"X_train_processing_result = pipeline_end.fit_transform(X_df_train)\n",
"X_test_processing_result = pipeline_end.transform(X_df_test)\n",
"\n",
"# Create the random forest model\n",
"model = RandomForestRegressor()\n",
"\n",
"# Parameter grid for the search\n",
"param_grid: dict[str, list[int | None]] = {\n",
"    'n_estimators': [50, 100, 200],  # Number of trees\n",
"    'max_depth': [None, 10, 20, 30],  # Maximum tree depth\n",
"    'min_samples_split': [2, 5, 10]  # Minimum number of samples required to split a node\n",
"}\n",
"\n",
"# Hyperparameter tuning with grid search\n",
"grid_search = GridSearchCV(estimator=model, \n",
"                           param_grid=param_grid,\n",
"                           scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Fit on the training data\n",
"grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Hyperparameter tuning results\n",
"print(\"Best parameters:\", grid_search.best_params_)\n",
"# Negate the score, since GridSearchCV returns the negative mean squared error\n",
"print(\"Best result (MSE):\", -grid_search.best_score_)"
]
},
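{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Comparison of hyperparameter sets:\n",
"\n",
"The results show that the parameters found with the old (full) grid give a much better model. The cross-validated mean squared error (MSE) for the old parameters is 184.142, far below that of the new parameters (`max_depth=5`, `min_samples_split=10`, `n_estimators=50`), which is 1283.436. The two estimates are not strictly comparable: the new grid is evaluated with `cv=2`, so each fit uses only half of the training data, while the old grid uses `cv=3`. In addition, `RandomForestRegressor` is created without `random_state`, so the scores vary between runs (the first search reported 196.949 for the same grid).\n",
"\n",
"On the test set, the model with the new parameters reaches an MSE of 159.033 (RMSE 12.611). The test MSE of the model with the old parameters is not computed in this cell, so the two models can only be compared visually on the plot. A single favorable test score does not outweigh the much worse cross-validation error of the new parameters, which suggests that the shallower trees (`max_depth=5`) underfit. Therefore, the parameters from the old grid are preferable: they show better generalization in cross-validation and a lower error."
]
},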
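{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of comparing two fixed parameter sets so that their cross-validation scores differ only by the parameters: both are evaluated on the same folds (`KFold` with `shuffle=True` and a fixed `random_state`) and with a fixed model seed. It runs on synthetic data from `make_regression`, used here as a hypothetical stand-in for the preprocessed training set, so the printed MSE values say nothing about the Tesla data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_regression\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import KFold, cross_val_score\n",
"\n",
"# Synthetic stand-in for the preprocessed training data\n",
"X, y = make_regression(n_samples=300, n_features=8, noise=10.0, random_state=0)\n",
"\n",
"# The same folds are reused for every parameter set\n",
"cv = KFold(n_splits=3, shuffle=True, random_state=42)\n",
"\n",
"param_sets = {\n",
"    'old': {'n_estimators': 50, 'max_depth': 10, 'min_samples_split': 5},\n",
"    'new': {'n_estimators': 50, 'max_depth': 5, 'min_samples_split': 10},\n",
"}\n",
"\n",
"cv_mse = {}\n",
"for name, params in param_sets.items():\n",
"    # A fixed random_state makes the forest itself reproducible\n",
"    model = RandomForestRegressor(random_state=0, **params)\n",
"    scores = cross_val_score(model, X, y, cv=cv, scoring='neg_mean_squared_error')\n",
"    cv_mse[name] = -scores.mean()  # flip the sign back to a positive MSE\n",
"    print(name, params, 'CV MSE:', cv_mse[name])"
]
},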
{
"cell_type": "code",
"execution_count": 388,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 36 candidates, totalling 108 fits\n",
"Old parameters: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 50}\n",
"Best result (MSE) with old parameters: 184.14248778487732\n",
"\n",
"New parameters: {'max_depth': 5, 'min_samples_split': 10, 'n_estimators': 50}\n",
"Best result (MSE) with new parameters: 1283.4356458868208\n",
"Mean squared error (MSE) on test data: 159.03284823315155\n",
"Root mean squared error (RMSE) on test data: 12.610822662822262\n"
]
},
{
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1sAAAHWCAYAAACBjZMqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hTZfvA8W+SNk2b7j1pgbKHyBDZoMhwIjhBBQf4U1TciggiqCjgAtfrqy84cKLiHigiIAjIlj0KpXsl3SPj/P44NFI6aEvbpO39uS4uTc5zzrmbJum5z/M896NRFEVBCCGEEEIIIUSD0jo7ACGEEEIIIYRoiSTZEkIIIYQQQohGIMmWEEIIIYQQQjQCSbaEEEIIIYQQohFIsiWEEEIIIYQQjUCSLSGEEEIIIYRoBJJsCSGEEEIIIUQjkGRLCCGEEEIIIRqBJFtCCCGEEEII0Qgk2RJCCCGEEI3uu+++Y+fOnY7Hq1atYu/evc4LSIgmIMmWEK3A0aNHufPOO2nXrh0GgwFfX18GDRrEq6++SnFxsbPDE0II0Qrs2bOHGTNmcPjwYf766y/+7//+j/z8fGeHJUSj0iiKojg7CCFE4/n++++59tpr8fDw4JZbbqF79+6UlZWxYcMGvvjiC6ZMmcLbb7/t7DCFEEK0cJmZmQwcOJAjR44AMH78eL744gsnRyVE45JkS4gWLCEhgZ49exIdHc2aNWuIiIiosP3IkSN8//33zJgxw0kRCiGEaE1KS0v5559/8PLyokuXLs4OR4hGJ8MIhWjBFi5cSEFBAe+++26lRAsgPj6+QqKl0Wi45557WLFiBZ06dcJgMNCnTx/WrVtXYb8TJ05w991306lTJzw9PQkKCuLaa6/l+PHjFdotX74cjUbj+Ofl5UWPHj145513KrSbMmUK3t7eleJbuXIlGo2GtWvXVnh+8+bNjBkzBj8/P7y8vBg2bBh//vlnhTZz585Fo9GQlZVV4fm///4bjUbD8uXLK5w/Li6uQruTJ0/i6emJRqOp9HP9+OOPDBkyBKPRiI+PD5dddlmt5h2Uvx7r1q3jzjvvJCgoCF9fX2655RZMJlOl9rU5z+7du5kyZYpjiGh4eDi33XYb2dnZVcYQFxdX4XdS/u/01zguLo7LL7+8xp/l+PHjaDQaFi9eXGlb9+7dGT58uOPx2rVr0Wg0rFy5strjnfk7eOqpp9Bqtfz2228V2k2bNg29Xs+uXbtqjE+j0TB37twKzy1atAiNRlMhtpr2r+7f6XGe/jq8/PLLxMbG4unpybBhw/jnn38qHffAgQNcc801BAYGYjAY6Nu3L998802VMUyZMqXK80+ZMqVS2x9//JFhw4bh4+ODr68v/fr146OPPnJsHz58eKWf+9lnn0Wr1VZot379eq699lratGmDh4cHMTExPPDAA5WGG8+dO5euXbvi7e2Nr68vF154IatWrarQprbHqsvnf/jw4XTv3r1S28WLF1f6rJ7tfVz+viw//v79+/H09OSWW26p0G7Dhg3odDoee+yxao8FtXtN6hL/119/zWWXXUZkZCQeHh60b9+e+fPnY7PZKuxb1Xu9/LumPt9ddf19nPm+2rp1q+O9WlWcHh4e9OnThy5dutTpMylEc+Xm7ACEEI3n22+/pV27dgwcOLDW+/zxxx98+umn3HfffXh4ePDGG28wZswYtmzZ4rhI2Lp1Kxs3buSGG24gOjqa48eP8+abbzJ8+HD27duHl5dXhWO+/PLLBAcHk5eXx//+9z+mTp1KXFwcI0eOrPPPtGbNGsaOHUufPn0cF+TLli3joosuYv369VxwwQV1PmZV5syZQ0lJSaXnP/jgAyZPnszo0aN54YUXKCoq4s0332Tw4MHs2LGjUtJWlXvuuQd/f3/mzp3LwYMHefPNNzlx4oTj4q8u51m9ejXHjh3j1ltvJTw8nL179/L222+zd+9e/vrrr0oXPABDhgxh2rRpgHqB+d
xzz9X/hWokTz75JN9++y233347e/bswcfHh59//pn//ve/zJ8/n/POO69OxzObzSxYsKBO+1xyySWVLrxffPHFKhPj999/n/z8fKZPn05JSQmvvvoqF110EXv27CEsLAyAvXv3MmjQIKKionj88ccxGo189tlnjBs3ji+++IKrr7660nE9PDwq3Jy44447KrVZvnw5t912G926dWPmzJn4+/uzY8cOfvrpJyZOnFjlz7Zs2TKefPJJXnzxxQptPv/8c4qKirjrrrsICgpiy5YtLF26lKSkJD7//HNHu8LCQq6++mri4uIoLi5m+fLlTJgwgU2bNjk+g7U9lqvo0qUL8+fP55FHHuGaa67hyiuvpLCwkClTptC5c2fmzZtX4/61eU3qYvny5Xh7e/Pggw/i7e3NmjVrmDNnDnl5eSxatKjOx2uI767aOFtSWq4+n0khmiVFCNEi5ebmKoBy1VVX1XofQAGUv//+2/HciRMnFIPBoFx99dWO54qKiirtu2nTJgVQ3n//fcdzy5YtUwAlISHB8dyhQ4cUQFm4cKHjucmTJytGo7HSMT///HMFUH7//XdFURTFbrcrHTp0UEaPHq3Y7fYK8bRt21a55JJLHM899dRTCqBkZmZWOObWrVsVQFm2bFmF88fGxjoe//PPP4pWq1XGjh1bIf78/HzF399fmTp1aoVjpqWlKX5+fpWeP1P569GnTx+lrKzM8fzChQsVQPn666/rfJ6qfhcff/yxAijr1q2rtC0qKkq59dZbHY9///33Cq+xoihKbGysctlll9X4syQkJCiAsmjRokrbunXrpgwbNqzSOT7//PNqj3fm70BRFGXPnj2KXq9X7rjjDsVkMilRUVFK3759FYvFUmNsiqK+l5966inH40cffVQJDQ1V+vTpUyG2mvafPn16pecvu+yyCnGWvw6enp5KUlKS4/nNmzcrgPLAAw84nrv44ouVHj16KCUlJY7n7Ha7MnDgQKVDhw6VzjVx4kTF29u7wnNGo1GZPHmy47HZbFZ8fHyU/v37K8XFxRXanv4ZGTZsmOPn/v777xU3NzfloYceqnTOqt5PCxYsUDQajXLixIlK28plZGQogLJ48eI6H6u2n//yn6Nbt26V2i5atKjSd83Z3sdVvfdtNpsyePBgJSwsTMnKylKmT5+uuLm5KVu3bq32ONWp6jWpS/xVvX533nmn4uXlVeE9pNFolDlz5lRod+Z3b12+U+r6+zj98/TDDz8ogDJmzBjlzEvMc/1MCtFcyTBCIVqovLw8AHx8fOq034ABA+jTp4/jcZs2bbjqqqv4+eefHcNXPD09HdstFgvZ2dnEx8fj7+/P9u3bKx3TZDKRlZXFsWPHePnll9HpdAwbNqxSu6ysrAr/zqxStXPnTg4fPszEiRPJzs52tCssLOTiiy9m3bp12O32Cvvk5ORUOGZubu5ZX4OZM2fSu3dvrr322grPr169GrPZzI033ljhmDqdjv79+/P777+f9digDoVzd3d3PL7rrrtwc3Pjhx9+qPN5Tv9dlJSUkJWVxYUXXghQ5e+irKwMDw+Ps8ZosVjIysoiOzsbq9VabbuioqJKv7czhzmVy8/PJysrC7PZfNbzgzoc8emnn+add95h9OjRZGVl8d577+HmVrdBGcnJySxdupTZs2dXOTyqIYwbN46oqCjH4wsuuID+/fs7fqc5OTmsWbOG6667zvE6lL++o0eP5vDhwyQnJ1c4ZklJCQaDocbzrl69mvz8fB5//PFKbavq1dyyZQvXXXcdEyZMqLJ35PT3U2FhIVlZWQwcOBBFUdixY0eFtuXvkaNHj/L888+j1WoZNGhQvY4FZ//8l7PZbJXaFhUVVdm2tu/jclqtluXLl1NQUMDYsWN54403mDlzJn379j3rvqefr7rXpC7xn/76lb9nhgwZQlFREQcOHHBsCw0NJSkpqca46vPdVdvfRzlFUZg5cyYTJkygf//+NbZtis+kEK5ChhEK0U
L5+voC1LmsbocOHSo917FjR4qKisjMzCQ8PJzi4mIWLFjAsmXLSE5ORjmtzk5VyUzv3r0d/+/h4cFrr71WaVhNYWE
|
|||
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Parameter grid for the old (full) search\n",
"old_param_grid: dict[str, list[int | None]] = {\n",
"    'n_estimators': [50, 100, 200],  # Number of trees\n",
"    'max_depth': [None, 10, 20, 30],  # Maximum tree depth\n",
"    'min_samples_split': [2, 5, 10]  # Minimum number of samples required to split a node\n",
"}\n",
"\n",
"# Grid search over the old parameters\n",
"old_grid_search = GridSearchCV(estimator=model, \n",
"                               param_grid=old_param_grid,\n",
"                               scoring='neg_mean_squared_error', cv=3, n_jobs=-1, verbose=2)\n",
"\n",
"# Fit on the training data\n",
"old_grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Results for the old parameters\n",
"old_best_params = old_grid_search.best_params_\n",
"# Negate the score, since GridSearchCV returns the negative MSE\n",
"old_best_mse = -old_grid_search.best_score_\n",
"\n",
"\n",
"# Parameter grid for the new values\n",
"new_param_grid: dict[str, list[int]] = {\n",
"    'n_estimators': [50],\n",
"    'max_depth': [5],\n",
"    'min_samples_split': [10]\n",
"}\n",
"\n",
"# Grid search over the new parameters\n",
"new_grid_search = GridSearchCV(estimator=model, \n",
"                               param_grid=new_param_grid,\n",
"                               scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"# Fit on the training data\n",
"new_grid_search.fit(X_train_processing_result, y_train)\n",
"\n",
"# Results for the new parameters\n",
"new_best_params = new_grid_search.best_params_\n",
"# Negate the score, since GridSearchCV returns the negative MSE\n",
"new_best_mse = -new_grid_search.best_score_\n",
"\n",
"\n",
"# Train a model with the best new parameters\n",
"model_best = RandomForestRegressor(**new_best_params)\n",
"model_best.fit(X_train_processing_result, y_train)\n",
"\n",
"# Predict on the test set\n",
"y_pred = model_best.predict(X_test_processing_result)\n",
"\n",
"# Evaluate the model\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"\n",
"# Print the results\n",
"print(\"Old parameters:\", old_best_params)\n",
"print(\"Best result (MSE) with old parameters:\", old_best_mse)\n",
"print(\"\\nNew parameters:\", new_best_params)\n",
"print(\"Best result (MSE) with new parameters:\", new_best_mse)\n",
"print(\"Mean squared error (MSE) on test data:\", mse)\n",
"print(\"Root mean squared error (RMSE) on test data:\", rmse)\n",
"\n",
"# Train a model with the best old parameters\n",
"model_old = RandomForestRegressor(**old_best_params)\n",
"model_old.fit(X_train_processing_result, y_train)\n",
"\n",
"# Predict on the test set with the old parameters\n",
"y_pred_old = model_old.predict(X_test_processing_result)\n",
"\n",
"# Plot actual vs predicted values\n",
"plt.figure(figsize=(10, 5))\n",
"plt.plot(y_test, label='Actual values', marker='o', linestyle='-', color='black')\n",
"plt.plot(y_pred_old, label='Predicted values (old parameters)', marker='x', linestyle='--', color='blue')\n",
"plt.plot(y_pred, label='Predicted values (new parameters)', marker='s', linestyle='--', color='orange')\n",
"plt.xlabel('Samples')\n",
"plt.ylabel('Values')\n",
"plt.title('Comparison of actual and predicted values')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Business goal #2 (classification task).\n",
"\n",
"### Achievable model quality:\n",
"**Main classification metrics:**\n",
"- **Accuracy** – the share of correctly classified examples among all observations. Easy to interpret, but may be insufficiently informative for imbalanced classes.\n",
"- **F1-Score** – the harmonic mean of precision and recall. Suitable for tasks where both false positives and false negatives matter, especially with imbalanced classes.\n",
"- **ROC AUC (Area Under the ROC Curve)** – reflects the model's ability to separate the positive and negative classes across all probability thresholds. A value of 0.5 corresponds to random guessing and 1.0 to a perfect model. Useful for evaluating models on imbalanced data.\n",
"- **Cohen's Kappa** – measures the agreement between the model's predictions and the true labels, corrected for agreement expected by chance. Values range from -1 (complete disagreement) to 1 (perfect agreement). Convenient for evaluation on imbalanced data.\n",
"- **MCC (Matthews Correlation Coefficient)** – the correlation between predictions and true classes that accounts for all outcome types (TP, TN, FP, FN). Values range from -1 (complete mismatch) to 1 (perfect match). Well suited to tasks with imbalanced classes.\n",
"- **Confusion Matrix** – shows how the model's predictions are distributed across the true classes.\n",
"\n",
"---\n",
"\n",
"### Choosing a baseline:\n",
"As the baseline for assessing prediction quality, we use the most common category of the target variable (\"Transaction\") in the training set. This approach, known as the \"most frequent class baseline\", means that the model always predicts the most frequent transaction type.\n",
"\n",
"---"
]
},
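{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch computing the metrics listed above with scikit-learn on eight made-up labels (1 = Sale, 0 = Option Exercise, matching the target encoding used later in this notebook). ROC AUC is computed from predicted probabilities, the other metrics from hard labels. With 3 TP, 3 TN, 1 FP and 1 FN, accuracy and F1 are 0.75, Cohen's kappa and MCC are 0.5, and ROC AUC is 15/16 = 0.9375."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import (accuracy_score, cohen_kappa_score, confusion_matrix,\n",
"                             f1_score, matthews_corrcoef, roc_auc_score)\n",
"\n",
"# Made-up true labels, predicted labels and predicted probabilities of class 1\n",
"y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0])\n",
"y_pred = np.array([1, 1, 1, 0, 0, 0, 1, 0])\n",
"y_score = np.array([0.9, 0.8, 0.7, 0.4, 0.2, 0.3, 0.6, 0.1])\n",
"\n",
"acc = accuracy_score(y_true, y_pred)\n",
"f1 = f1_score(y_true, y_pred)\n",
"auc = roc_auc_score(y_true, y_score)  # uses scores, not hard labels\n",
"kappa = cohen_kappa_score(y_true, y_pred)\n",
"mcc = matthews_corrcoef(y_true, y_pred)\n",
"cm = confusion_matrix(y_true, y_pred)  # rows: true class, columns: predicted class\n",
"\n",
"print('Accuracy:', acc)\n",
"print('F1:', f1)\n",
"print('ROC AUC:', auc)\n",
"print('Cohen kappa:', kappa)\n",
"print('MCC:', mcc)\n",
"print('Confusion matrix:', cm.tolist())"
]
},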
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data split:"
]
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 389,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Самый частый класс: Sale\n",
|
|||
|
"Baseline Accuracy: 0.59375\n",
|
|||
|
"Baseline F1: 0.4424019607843137\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import accuracy_score, f1_score\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# Define the target and the input features\n",
|
|||
|
"y_feature: str = 'Transaction'\n",
|
|||
|
"X_features: list[str] = df.drop(columns=y_feature).columns.tolist()\n",
|
|||
|
"\n",
|
|||
|
"# Split the data into training and test sets\n",
|
|||
|
"X_df_train, X_df_test, y_df_train, y_df_test = split_into_train_test(\n",
|
|||
|
" df, \n",
|
|||
|
" stratify_colname=y_feature, \n",
|
|||
|
" frac_train=0.8, \n",
|
|||
|
" random_state=42 \n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Find the most frequent class in the training set\n",
|
|||
|
"most_frequent_class = y_df_train.mode().values[0][0]\n",
|
|||
|
"print(f\"Самый частый класс: {most_frequent_class}\")\n",
|
|||
|
"\n",
|
|||
|
"# Baseline predictions: every prediction equals the most frequent class\n",
|
|||
|
"baseline_predictions: list[str] = [most_frequent_class] * len(y_df_test)\n",
|
|||
|
"\n",
|
|||
|
"# Evaluate the baseline\n",
|
|||
|
"print('Baseline Accuracy:', accuracy_score(y_df_test, baseline_predictions))\n",
|
|||
|
"print('Baseline F1:', f1_score(y_df_test, baseline_predictions, average='weighted'))\n",
|
|||
|
"\n",
|
|||
|
"# Encode the binary target as 0/1 (Sale -> 1, Option Exercise -> 0)\n",
|
|||
|
"y_df_train = y_df_train['Transaction'].map({'Sale': 1, 'Option Exercise': 0})\n",
|
|||
|
"y_df_test = y_df_test['Transaction'].map({'Sale': 1, 'Option Exercise': 0})"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Choosing the models:\n",
|
|||
|
"\n",
|
|||
|
"The following models were chosen for training:\n",
|
|||
|
"1. **Random Forest**: an ensemble of many decision trees, each trained on a bootstrap sample with random feature subsets. It handles nonlinear dependencies and noisy data well and overfits less than a single decision tree.\n",
|
|||
|
"2. **Logistic Regression**: a statistical method for binary classification. It models the probability of the positive class as the logistic (sigmoid) function of a linear combination of the features. It is easy to interpret and fast to train.\n",
|
|||
|
"3. **k-Nearest Neighbors (KNN)**: predicts the class by a majority vote among the k closest training examples. It is intuitive and has no explicit training phase, since it simply stores the training set. Prediction can be slow on large data, and results depend on the choice of k and on feature scaling.\n",
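"\n",
"To make the descriptions concrete, here is a toy pure-Python sketch of two prediction rules:\n",
"\n",
"- the logistic (sigmoid) function that logistic regression applies to a linear score;\n",
"- the majority vote among the k nearest neighbors used by KNN.\n",
"\n",
"The weights, points and `k` are made up for illustration; scikit-learn's `KNeighborsClassifier` uses k = 5 by default. The library implementations add regularization, efficient neighbor search and other details not shown here.\n",
"\n",
"```python\n",
"import math\n",
"from collections import Counter\n",
"\n",
"\n",
"def logistic_predict_proba(x: list[float], w: list[float], b: float) -> float:\n",
"    # P(y = 1 | x) = sigmoid(w . x + b)\n",
"    z = sum(wi * xi for wi, xi in zip(w, x)) + b\n",
"    return 1.0 / (1.0 + math.exp(-z))\n",
"\n",
"\n",
"def knn_predict(x: list[float], X: list[list[float]], y: list[int], k: int = 3) -> int:\n",
"    # Sort training points by Euclidean distance to x and vote among the k nearest\n",
"    nearest = sorted(range(len(X)), key=lambda i: math.dist(x, X[i]))[:k]\n",
"    return Counter(y[i] for i in nearest).most_common(1)[0][0]\n",
"\n",
"\n",
"# Toy data: two points of class 0 near the origin, three points of class 1 near (1, 1)\n",
"X_toy = [[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [0.9, 1.1], [1.2, 0.8]]\n",
"y_toy = [0, 0, 1, 1, 1]\n",
"print(logistic_predict_proba([1.0, 1.0], w=[2.0, 2.0], b=-2.0))\n",
"print(knn_predict([0.95, 1.0], X_toy, y_toy, k=3))\n",
"```\n",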
|
|||
|
"\n",
|
|||
|
"---"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Building the pipeline:\n",
|
|||
|
"\n",
|
|||
|
"The pipelines for numeric and categorical features, as well as the main pipeline `pipeline_end`, were already built earlier for the regression task."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Pipeline demonstration:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 390,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>Cost</th>\n",
|
|||
|
" <th>Shares</th>\n",
|
|||
|
" <th>Value ($)</th>\n",
|
|||
|
" <th>Shares Total</th>\n",
|
|||
|
" <th>Year</th>\n",
|
|||
|
" <th>Month</th>\n",
|
|||
|
" <th>Day</th>\n",
|
|||
|
" <th>Insider Trading_DENHOLM ROBYN M</th>\n",
|
|||
|
" <th>Insider Trading_Kirkhorn Zachary</th>\n",
|
|||
|
" <th>Insider Trading_Musk Elon</th>\n",
|
|||
|
" <th>Insider Trading_Musk Kimbal</th>\n",
|
|||
|
" <th>Insider Trading_Taneja Vaibhav</th>\n",
|
|||
|
" <th>Insider Trading_Wilson-Thompson Kathleen</th>\n",
|
|||
|
" <th>Relationship_Chief Accounting Officer</th>\n",
|
|||
|
" <th>Relationship_Chief Financial Officer</th>\n",
|
|||
|
" <th>Relationship_Director</th>\n",
|
|||
|
" <th>Relationship_SVP Powertrain and Energy Eng.</th>\n",
|
|||
|
" <th>Transaction_Sale</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>-0.966516</td>\n",
|
|||
|
" <td>-0.361759</td>\n",
|
|||
|
" <td>-0.450022</td>\n",
|
|||
|
" <td>-0.343599</td>\n",
|
|||
|
" <td>0.715678</td>\n",
|
|||
|
" <td>-0.506108</td>\n",
|
|||
|
" <td>-0.400623</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>-1.074894</td>\n",
|
|||
|
" <td>1.225216</td>\n",
|
|||
|
" <td>-0.414725</td>\n",
|
|||
|
" <td>-0.319938</td>\n",
|
|||
|
" <td>-1.397276</td>\n",
|
|||
|
" <td>0.801338</td>\n",
|
|||
|
" <td>0.906673</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>-1.074894</td>\n",
|
|||
|
" <td>1.211753</td>\n",
|
|||
|
" <td>-0.415027</td>\n",
|
|||
|
" <td>-0.320141</td>\n",
|
|||
|
" <td>-1.397276</td>\n",
|
|||
|
" <td>1.062828</td>\n",
|
|||
|
" <td>-0.098939</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>1.167142</td>\n",
|
|||
|
" <td>0.037499</td>\n",
|
|||
|
" <td>1.023612</td>\n",
|
|||
|
" <td>-0.325853</td>\n",
|
|||
|
" <td>-1.397276</td>\n",
|
|||
|
" <td>1.062828</td>\n",
|
|||
|
" <td>-0.501184</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>1.217886</td>\n",
|
|||
|
" <td>-0.075287</td>\n",
|
|||
|
" <td>0.632973</td>\n",
|
|||
|
" <td>-0.330205</td>\n",
|
|||
|
" <td>-1.397276</td>\n",
|
|||
|
" <td>1.062828</td>\n",
|
|||
|
" <td>-0.501184</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>5</th>\n",
|
|||
|
" <td>0.505872</td>\n",
|
|||
|
" <td>-0.361021</td>\n",
|
|||
|
" <td>-0.443679</td>\n",
|
|||
|
" <td>-0.343698</td>\n",
|
|||
|
" <td>0.715678</td>\n",
|
|||
|
" <td>-0.767598</td>\n",
|
|||
|
" <td>1.308918</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>6</th>\n",
|
|||
|
" <td>-1.088674</td>\n",
|
|||
|
" <td>-0.357532</td>\n",
|
|||
|
" <td>-0.450389</td>\n",
|
|||
|
" <td>-0.342863</td>\n",
|
|||
|
" <td>0.715678</td>\n",
|
|||
|
" <td>0.278360</td>\n",
|
|||
|
" <td>-0.903429</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>7</th>\n",
|
|||
|
" <td>-0.692146</td>\n",
|
|||
|
" <td>-0.355855</td>\n",
|
|||
|
" <td>-0.445383</td>\n",
|
|||
|
" <td>-0.343220</td>\n",
|
|||
|
" <td>0.715678</td>\n",
|
|||
|
" <td>0.801338</td>\n",
|
|||
|
" <td>1.409480</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>8</th>\n",
|
|||
|
" <td>-1.088674</td>\n",
|
|||
|
" <td>-0.361181</td>\n",
|
|||
|
" <td>-0.450389</td>\n",
|
|||
|
" <td>-0.343649</td>\n",
|
|||
|
" <td>-1.397276</td>\n",
|
|||
|
" <td>1.062828</td>\n",
|
|||
|
" <td>-0.903429</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>9</th>\n",
|
|||
|
" <td>1.091997</td>\n",
|
|||
|
" <td>-0.204531</td>\n",
|
|||
|
" <td>0.114712</td>\n",
|
|||
|
" <td>1.538166</td>\n",
|
|||
|
" <td>0.715678</td>\n",
|
|||
|
" <td>-1.029087</td>\n",
|
|||
|
" <td>1.208357</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" Cost Shares Value ($) Shares Total Year Month Day \\\n",
|
|||
|
"0 -0.966516 -0.361759 -0.450022 -0.343599 0.715678 -0.506108 -0.400623 \n",
|
|||
|
"1 -1.074894 1.225216 -0.414725 -0.319938 -1.397276 0.801338 0.906673 \n",
|
|||
|
"2 -1.074894 1.211753 -0.415027 -0.320141 -1.397276 1.062828 -0.098939 \n",
|
|||
|
"3 1.167142 0.037499 1.023612 -0.325853 -1.397276 1.062828 -0.501184 \n",
|
|||
|
"4 1.217886 -0.075287 0.632973 -0.330205 -1.397276 1.062828 -0.501184 \n",
|
|||
|
"5 0.505872 -0.361021 -0.443679 -0.343698 0.715678 -0.767598 1.308918 \n",
|
|||
|
"6 -1.088674 -0.357532 -0.450389 -0.342863 0.715678 0.278360 -0.903429 \n",
|
|||
|
"7 -0.692146 -0.355855 -0.445383 -0.343220 0.715678 0.801338 1.409480 \n",
|
|||
|
"8 -1.088674 -0.361181 -0.450389 -0.343649 -1.397276 1.062828 -0.903429 \n",
|
|||
|
"9 1.091997 -0.204531 0.114712 1.538166 0.715678 -1.029087 1.208357 \n",
|
|||
|
"\n",
|
|||
|
" Insider Trading_DENHOLM ROBYN M Insider Trading_Kirkhorn Zachary \\\n",
|
|||
|
"0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 \n",
|
|||
|
"5 0.0 0.0 \n",
|
|||
|
"6 0.0 0.0 \n",
|
|||
|
"7 0.0 0.0 \n",
|
|||
|
"8 0.0 0.0 \n",
|
|||
|
"9 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Insider Trading_Musk Elon Insider Trading_Musk Kimbal \\\n",
|
|||
|
"0 0.0 0.0 \n",
|
|||
|
"1 1.0 0.0 \n",
|
|||
|
"2 1.0 0.0 \n",
|
|||
|
"3 1.0 0.0 \n",
|
|||
|
"4 1.0 0.0 \n",
|
|||
|
"5 0.0 0.0 \n",
|
|||
|
"6 0.0 0.0 \n",
|
|||
|
"7 0.0 0.0 \n",
|
|||
|
"8 0.0 0.0 \n",
|
|||
|
"9 1.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Insider Trading_Taneja Vaibhav Insider Trading_Wilson-Thompson Kathleen \\\n",
|
|||
|
"0 1.0 0.0 \n",
|
|||
|
"1 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 \n",
|
|||
|
"5 0.0 0.0 \n",
|
|||
|
"6 1.0 0.0 \n",
|
|||
|
"7 0.0 0.0 \n",
|
|||
|
"8 1.0 0.0 \n",
|
|||
|
"9 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Relationship_Chief Accounting Officer \\\n",
|
|||
|
"0 1.0 \n",
|
|||
|
"1 0.0 \n",
|
|||
|
"2 0.0 \n",
|
|||
|
"3 0.0 \n",
|
|||
|
"4 0.0 \n",
|
|||
|
"5 0.0 \n",
|
|||
|
"6 1.0 \n",
|
|||
|
"7 0.0 \n",
|
|||
|
"8 1.0 \n",
|
|||
|
"9 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Relationship_Chief Financial Officer Relationship_Director \\\n",
|
|||
|
"0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 \n",
|
|||
|
"3 0.0 0.0 \n",
|
|||
|
"4 0.0 0.0 \n",
|
|||
|
"5 0.0 0.0 \n",
|
|||
|
"6 0.0 0.0 \n",
|
|||
|
"7 0.0 0.0 \n",
|
|||
|
"8 0.0 0.0 \n",
|
|||
|
"9 0.0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" Relationship_SVP Powertrain and Energy Eng. Transaction_Sale \n",
|
|||
|
"0 0.0 0.0 \n",
|
|||
|
"1 0.0 0.0 \n",
|
|||
|
"2 0.0 0.0 \n",
|
|||
|
"3 0.0 1.0 \n",
|
|||
|
"4 0.0 1.0 \n",
|
|||
|
"5 1.0 1.0 \n",
|
|||
|
"6 0.0 0.0 \n",
|
|||
|
"7 1.0 1.0 \n",
|
|||
|
"8 0.0 0.0 \n",
|
|||
|
"9 0.0 1.0 "
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 390,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Apply the preprocessing pipeline to the training set\n",
|
|||
|
"preprocessing_result = pipeline_end.fit_transform(X_df_train)\n",
|
|||
|
"preprocessed_df = pd.DataFrame(\n",
|
|||
|
" preprocessing_result,\n",
|
|||
|
" columns=pipeline_end.get_feature_names_out(),\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"preprocessed_df.head(10)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Model evaluation:\n",
|
|||
|
"\n",
|
|||
|
"Evaluation of the training results:\n",
|
|||
|
"1. **Random Forest**:\n",
|
|||
|
"   - Metrics:\n",
|
|||
|
"     - Precision (train): 1.0\n",
|
|||
|
"     - Precision (test): 1.0\n",
|
|||
|
"     - Recall (train): 1.0\n",
|
|||
|
"     - Recall (test): 1.0\n",
|
|||
|
"     - Accuracy (train): 1.0\n",
|
|||
|
"     - Accuracy (test): 1.0\n",
|
|||
|
"     - F1 Score (train): 1.0\n",
|
|||
|
"     - F1 Score (test): 1.0\n",
|
|||
|
"     - ROC AUC (test): 1.0\n",
|
|||
|
"     - Cohen Kappa (test): 1.0\n",
|
|||
|
"     - MCC (test): 1.0\n",
|
|||
|
"     - Confusion Matrix (test):\n",
|
|||
|
"       ```\n",
|
|||
|
"       [[13, 0],\n",
|
|||
|
"        [ 0, 19]]\n",
|
|||
|
"       ```\n",
|
|||
|
"   - Conclusion: Random Forest classifies both samples perfectly. Such scores are suspicious, but they are not a sign of overfitting, which would show up as a gap between training and test scores. The preprocessed features include `Transaction_Sale`, a one-hot encoding of the target itself, so the model most likely reads the answer directly (target leakage). These metrics therefore do not measure genuine predictive ability.\n",
|
|||
|
"2. **Logistic Regression**:\n",
|
|||
|
"   - Metrics:\n",
|
|||
|
"     - Precision (train): 1.0\n",
|
|||
|
"     - Precision (test): 1.0\n",
|
|||
|
"     - Recall (train): 1.0\n",
|
|||
|
"     - Recall (test): 1.0\n",
|
|||
|
"     - Accuracy (train): 1.0\n",
|
|||
|
"     - Accuracy (test): 1.0\n",
|
|||
|
"     - F1 Score (train): 1.0\n",
|
|||
|
"     - F1 Score (test): 1.0\n",
|
|||
|
"     - ROC AUC (test): 1.0\n",
|
|||
|
"     - Cohen Kappa (test): 1.0\n",
|
|||
|
"     - MCC (test): 1.0\n",
|
|||
|
"     - Confusion Matrix (test):\n",
|
|||
|
"       ```\n",
|
|||
|
"       [[13, 0],\n",
|
|||
|
"        [ 0, 19]]\n",
|
|||
|
"       ```\n",
|
|||
|
"   - Conclusion: Logistic Regression also scores perfectly. This should not be read as evidence that the data are naturally linearly separable. Because `Transaction_Sale` duplicates the target, that single feature already separates the classes.\n",
|
|||
|
"3. **k-Nearest Neighbors (KNN)**:\n",
|
|||
|
"   - Metrics:\n",
|
|||
|
"     - Precision (train): 1.0\n",
|
|||
|
"     - Precision (test): 1.0\n",
|
|||
|
"     - Recall (train): 0.95\n",
|
|||
|
"     - Recall (test): 0.947\n",
|
|||
|
"     - Accuracy (train): 0.968\n",
|
|||
|
"     - Accuracy (test): 0.969\n",
|
|||
|
"     - F1 Score (train): 0.974\n",
|
|||
|
"     - F1 Score (test): 0.973\n",
|
|||
|
"     - ROC AUC (test): 0.974\n",
|
|||
|
"     - Cohen Kappa (test): 0.936\n",
|
|||
|
"     - MCC (test): 0.938\n",
|
|||
|
"     - Confusion Matrix (test):\n",
|
|||
|
"       ```\n",
|
|||
|
"       [[13, 0],\n",
|
|||
|
"        [ 1, 18]]\n",
|
|||
|
"       ```\n",
|
|||
|
"   - Conclusion: KNN performs well but not perfectly. One `Sale` transaction in the test set was predicted as `Option Exercise` (recall 0.947 at precision 1.0). Training and test scores are almost identical, so there is no sign of overfitting. KNN cannot learn which features matter, because it uses an unweighted distance over all features. The leaked `Transaction_Sale` column is therefore only one coordinate among many, which most likely explains why this is the only model that makes an error.\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 391,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Модель: RandomForestClassifier\n",
|
|||
|
"\tPrecision_train: 1.0\n",
|
|||
|
"\tPrecision_test: 1.0\n",
|
|||
|
"\tRecall_train: 1.0\n",
|
|||
|
"\tRecall_test: 1.0\n",
|
|||
|
"\tAccuracy_train: 1.0\n",
|
|||
|
"\tAccuracy_test: 1.0\n",
|
|||
|
"\tF1_train: 1.0\n",
|
|||
|
"\tF1_test: 1.0\n",
|
|||
|
"\tROC_AUC_test: 1.0\n",
|
|||
|
"\tCohen_kappa_test: 1.0\n",
|
|||
|
"\tMCC_test: 1.0\n",
|
|||
|
"\tConfusion_matrix: [[13 0]\n",
|
|||
|
" [ 0 19]]\n",
|
|||
|
"\n",
|
|||
|
"Модель: LogisticRegression\n",
|
|||
|
"\tPrecision_train: 1.0\n",
|
|||
|
"\tPrecision_test: 1.0\n",
|
|||
|
"\tRecall_train: 1.0\n",
|
|||
|
"\tRecall_test: 1.0\n",
|
|||
|
"\tAccuracy_train: 1.0\n",
|
|||
|
"\tAccuracy_test: 1.0\n",
|
|||
|
"\tF1_train: 1.0\n",
|
|||
|
"\tF1_test: 1.0\n",
|
|||
|
"\tROC_AUC_test: 1.0\n",
|
|||
|
"\tCohen_kappa_test: 1.0\n",
|
|||
|
"\tMCC_test: 1.0\n",
|
|||
|
"\tConfusion_matrix: [[13 0]\n",
|
|||
|
" [ 0 19]]\n",
|
|||
|
"\n",
|
|||
|
"Модель: KNN\n",
|
|||
|
"\tPrecision_train: 1.0\n",
|
|||
|
"\tPrecision_test: 1.0\n",
|
|||
|
"\tRecall_train: 0.95\n",
|
|||
|
"\tRecall_test: 0.9473684210526315\n",
|
|||
|
"\tAccuracy_train: 0.967741935483871\n",
|
|||
|
"\tAccuracy_test: 0.96875\n",
|
|||
|
"\tF1_train: 0.9743589743589743\n",
|
|||
|
"\tF1_test: 0.972972972972973\n",
|
|||
|
"\tROC_AUC_test: 0.9736842105263157\n",
|
|||
|
"\tCohen_kappa_test: 0.9359999999999999\n",
|
|||
|
"\tMCC_test: 0.9379228369755696\n",
|
|||
|
"\tConfusion_matrix: [[13 0]\n",
|
|||
|
" [ 1 18]]\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
|||
|
"from sklearn.linear_model import LogisticRegression\n",
|
|||
|
"from sklearn.neighbors import KNeighborsClassifier\n",
|
|||
|
"from sklearn import metrics\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# Evaluate several models on a common set of metrics\n",
|
|||
|
"def evaluate_models(models: dict[str, Any],\n",
|
|||
|
"                    pipeline_end: Pipeline,\n",
|
|||
|
"                    X_train: DataFrame, y_train,\n",
|
|||
|
"                    X_test: DataFrame, y_test) -> dict[str, dict[str, Any]]:\n",
|
|||
|
"    results: dict[str, dict[str, Any]] = {}\n",
|
|||
|
"\n",
|
|||
|
"    for model_name, model in models.items():\n",
|
|||
|
"        # Build a pipeline for the current model: shared preprocessing + model\n",
|
|||
|
"        model_pipeline = Pipeline(\n",
|
|||
|
"            [\n",
|
|||
|
"                (\"pipeline\", pipeline_end),\n",
|
|||
|
"                (\"model\", model),\n",
|
|||
|
"            ]\n",
|
|||
|
"        )\n",
|
|||
|
"\n",
|
|||
|
"        # Train the model\n",
|
|||
|
"        model_pipeline.fit(X_train, y_train)\n",
|
|||
|
"\n",
|
|||
|
"        # Predict on the training and test sets\n",
|
|||
|
"        y_train_predict = model_pipeline.predict(X_train)\n",
|
|||
|
"        y_test_predict = model_pipeline.predict(X_test)\n",
|
|||
|
"\n",
|
|||
|
"        # Compute the metrics for the current model\n",
|
|||
|
"        metrics_dict: dict[str, Any] = {\n",
|
|||
|
"            \"Precision_train\": metrics.precision_score(y_train, y_train_predict),\n",
|
|||
|
"            \"Precision_test\": metrics.precision_score(y_test, y_test_predict),\n",
|
|||
|
"            \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
|
|||
|
"            \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
|
|||
|
"            \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
|
|||
|
"            \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
|
|||
|
"            \"F1_train\": metrics.f1_score(y_train, y_train_predict),\n",
|
|||
|
"            \"F1_test\": metrics.f1_score(y_test, y_test_predict),\n",
|
|||
|
"            \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_predict),  # hard labels: equals balanced accuracy\n",
|
|||
|
"            \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
|
|||
|
"            \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
|
|||
|
"            \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
|
|||
|
"        }\n",
|
|||
|
"\n",
|
|||
|
"        # Store the results\n",
|
|||
|
"        results[model_name] = metrics_dict\n",
|
|||
|
"\n",
|
|||
|
"    return results\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# Models selected for classification\n",
|
|||
|
"models_classification: dict[str, Any] = {\n",
|
|||
|
" \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
|
|||
|
" \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
|
|||
|
" \"KNN\": KNeighborsClassifier(),\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"results: dict[str, dict[str, Any]] = evaluate_models(models_classification,\n",
|
|||
|
" pipeline_end,\n",
|
|||
|
" X_df_train, y_df_train,\n",
|
|||
|
" X_df_test, y_df_test)\n",
|
|||
|
"\n",
|
|||
|
"# Print the results\n",
|
|||
|
"for model_name, metrics_dict in results.items():\n",
|
|||
|
"    print(f\"Модель: {model_name}\")\n",
|
|||
|
"    for metric_name, value in metrics_dict.items():\n",
|
|||
|
"        print(f\"\\t{metric_name}: {value}\")\n",
|
|||
|
"    print()"
|
|||
|
]
|
|||
|
},
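{
"cell_type": "markdown",
"metadata": {},
"source": [
"The perfect scores above are most likely explained by target leakage rather than by the models themselves. The preprocessed training data shown in the pipeline demonstration contains a `Transaction_Sale` column, so `X_df_train` still includes the `Transaction` target. A simple safeguard is to check, before training, whether any feature is an exact copy of the encoded target. The sketch below applies a hypothetical pure-Python helper, `find_leaky_features` (not part of the notebook or of any library), to toy columns. It only catches exact copies of a 0/1 target or of its complement, not subtler leakage.\n",
"\n",
"```python\n",
"def find_leaky_features(features: dict[str, list[float]], target: list[int]) -> list[str]:\n",
"    # Flag features equal to the 0/1 target, or to its complement, on every row\n",
"    leaky = []\n",
"    for name, values in features.items():\n",
"        same = all(v == t for v, t in zip(values, target))\n",
"        flipped = all(v == 1 - t for v, t in zip(values, target))\n",
"        if same or flipped:\n",
"            leaky.append(name)\n",
"    return leaky\n",
"\n",
"\n",
"# Toy data: 'Transaction_Sale' copies the target, 'Shares' does not\n",
"toy_features = {\n",
"    'Transaction_Sale': [1.0, 0.0, 1.0, 0.0],\n",
"    'Shares': [0.3, 0.3, -1.2, 0.5],\n",
"}\n",
"toy_target = [1, 0, 1, 0]\n",
"print(find_leaky_features(toy_features, toy_target))\n",
"```\n",
"\n",
"On the real data, one could pass the columns of `preprocessed_df` as lists together with `y_df_train.tolist()`. The natural remedy would be to remove `Transaction` from the model inputs, and from the column lists used by `pipeline_end`, before fitting."
]
},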
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Матрица ошибок:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 392,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABEoAAAQTCAYAAABzx8zfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAADcLUlEQVR4nOzdeZyN5f/H8feZYfaNDGMYY9/CiCLZQ0NRJAmFhCIhEZJdTWRpsf5aUKlEQpQ1+1IpQyRZs8s+xjbLuX9/zJnz7ZjFGXNus3g9H4/7kXPf97nOdY5p3rfPua7rthiGYQgAAAAAAAByy+oOAAAAAAAAZBcUSgAAAAAAAGwolAAAAAAAANhQKAEAAAAAALChUAIAAAAAAGBDoQQAAAAAAMCGQgkAAAAAAIBNnqzuAAAAyP6uX7+uuLg4l7bp4eEhLy8vl7YJAMDdgmw2D4USAACQruvXr6tEuJ9O/Zvo0nZDQkJ06NAhLsgAAMggstlcFEoAAEC64uLidOrfRP3zW3EF+Ltm1m7MZavCqx9WXFzcXX8xBgBARpHN5qJQAgAAnOLnb5Gfv8UlbVnlmnYAALibkc3mYDFXAAAAAAAAG0aUAAAApyQaViUarmsLAABkDtlsDgolAADAKVYZsso1V2OuagcAgLsZ2WwOpt4AAAAAAADYMKIEAAA4xSqrXDUo13UtAQBw9yKbzUGhBAAAOCXRMJRouGZYrqvaAQDgbkY2m4OpNwAAAAAAADYUSgDcUufOnVW8ePGs7sZd5fDhw7JYLJo1a1aW9aF48eLq3Lmzw759+/bpkUceUWBgoCwWixYuXKhZs2bJYrHo8OHDWdJP3DnJC8a5agNwZzVo0EANGjRwWXup5QScs3btWlksFq1duzaru4Icjmw2B4USIJtJ/kdn8pYnTx4VKVJEnTt31vHjx7O6e3dc586dHT6P/27Lli3L6u6lcOLECY0YMULR0dFpnrN27Vo9+eSTCgkJkYeHhwoWLKgWLVpowYIFd66jt6lTp076448/9NZbb+nzzz/X/fffn9VdAoAcKTnvt23bltVdSdfmzZs1YsQIXbx40SXtJX8RkLy5ubkpf/78atasmbZs2eKS1wCAzGKNEiCbGjVqlEqUKKHr169r69atmjVrljZu3Khdu3bJy8srq7t3R3l6eurjjz9OsT8iIiILepO+EydOaOTIkSpevLiqVq2a4vjw4cM1atQolSlTRi+++KLCw8N17tw5/fDDD2rdurXmzJmj9u3b3/mOp2Lv3r1yc/tfPf3atWvasmWLhgwZol69etn3P/fcc3rmmWfk6emZFd3EHWSVoURuQQjkWCtWrMjwczZv3qyRI0eqc+fOCgoKcjh2c05kRLt27fToo48qMTFRf//9t6ZOnaqGDRvq119/VeXKlW+rzZykXr16unbtmjw8PLK6K8jhyGZzUCgBsqlmzZrZv63v2rWrChQooLFjx2rx4sV6+umns7h3d1aePHn07LPPmtL21atX5ePjY0rbN5s/f75GjRqlp556Sl9++aXy5s1rPzZgwAAtX75c8fHxd6Qvzri58HHmzBlJSnGh7O7uLnd3d5e97pUrV+Tr6+uy9uA6rhyWy8UYcOe5+h/lmSmQV6tWzSHb69atq2bNmmnatGmaOnWqK7rntKzIHTc3t7vuiy+Yg2w2B1NvgByibt26kqQDBw5IkuLi4jRs2DBVr15dgYGB8vX1Vd26dbVmzRqH5yUPcR0/frz+7//+T6VKlZKnp6ceeOAB/frrryleZ+HChapUqZK8vLxUqVIlfffdd6n258qVK3rttdcUFhYmT09PlStXTuPHj5dx02rZFotFvXr10rx581SxYkV5e3urVq1a+uOPPyRJM2bMUOnSpeXl5aUGDRrc9joXU6dO1b333itPT0+Fhobq5ZdfTjFMuEGDBqpUqZJ+++031atXTz4+PnrjjTckSTdu3NDw4cNVunRpeXp6KiwsTK+//r
pu3Ljh0MbKlStVp04dBQUFyc/PT+XKlbO3sXbtWj3wwAOSpOeff94+rDh5nZGhQ4cqf/78+vTTTx2KJMkiIyPVvHnzNN/jzp071blzZ5UsWVJeXl4KCQlRly5ddO7cOYfzLl++rL59+6p48eLy9PRUwYIF1aRJE/3+++/2c/bt26fWrVsrJCREXl5eKlq0qJ555hldunTJfs5/556PGDFC4eHhkpKKOhaLxb5uTVprlPz444+qW7eufH195e/vr8cee0y7d+92OKdz587y8/PTgQMH9Oijj8rf318dOnRI8zMAgLvJ9u3b1axZMwUEBMjPz0+NGjXS1q1bU5y3c+dO1a9fX97e3ipatKjGjBmjmTNnpvjdnNoaJR9++KHuvfde+fj4KF++fLr//vv15ZdfSkr63T9gwABJUokSJey5ltxmamuUXLx4Ua+++qo9g4oWLaqOHTvq7Nmz6b7Xm69z/tte37597dcbpUuX1tixY2W1Ot7G9Ny5c3ruuecUEBCgoKAgderUSTt27Eix3ld6uWO1WvXee+/p3nvvlZeXlwoVKqQXX3xRFy5ccHitbdu2KTIyUgUKFJC3t7dKlCihLl26OJzz9ddfq3r16vL391dAQIAqV66s999/3348rTVK5s2bp+rVq8vb21sFChTQs88+m2LqdfJ7OH78uFq2bCk/Pz8FBwerf//+SkxMTPdzBuAcRpQAOUTyRUm+fPkkSTExMfr444/Vrl07devWTZcvX9Ynn3yiyMhI/fLLLymmfXz55Ze6fPmyXnzxRVksFo0bN05PPvmkDh48aP9H+4oVK9S6dWtVrFhRUVFROnfunJ5//nkVLVrUoS3DMPT4449rzZo1euGFF1S1alUtX75cAwYM0PHjxzVp0iSH8zds2KDFixfr5ZdfliRFRUWpefPmev311zV16lT17NlTFy5c0Lhx49SlSxf99NNPKd7/zRdYefPmVWBgoKSkC7mRI0eqcePG6tGjh/bu3atp06bp119/1aZNmxyKEufOnVOzZs30zDPP6Nlnn1WhQoVktVr1+OOPa+PGjerevbsqVKigP/74Q5MmTdLff/+thQsXSpJ2796t5s2bq0qVKho1apQ8PT21f/9+bdq0SZJUoUIFjRo1SsOGDVP37t3tF30PPfSQ9u3bp7/++ktdunSRv7+/U3/nN1u5cqUOHjyo559/XiEhIdq9e7f+7//+T7t379bWrVtlsVgkSS+99JLmz5+vXr16qWLFijp37pw2btyoPXv2qFq1aoqLi1NkZKRu3LihV155RSEhITp+/LiWLFmiixcv2j/X/3ryyScVFBSkV1991T5c2s/PL82+fv755+rUqZMiIyM1duxYXb16VdOmTVOdOnW0fft2h8WBExISFBkZqTp16mj8+PF3bIQPMo5bEAJ3zu7du1W3bl0FBATo9ddfV968eTVjxgw1aNBA69atU82aNSVJx48fV8OGDWWxWDR48GD5+vrq448/dmq0x0cffaTevXvrqaeeUp8+fXT9+nXt3LlTP//8s9q3b68nn3xSf//9t7766itNmjRJBQoUkCQFBwen2l5sbKzq1q2rPXv2qEuXLqpWrZrOnj2rxYsX69ixY/bnp+bm6xwpadRn/fr1dfz4cb344osqVqyYNm/erMGDB+vkyZN67733JCUVOFq0aKFffvlFPXr0UPny5bVo0SJ16tQp1ddKK3defPFFzZo1S88//7x69+6tQ4cOafLkydq+fbv9euLff//VI488ouDgYA0aNEhBQUE6fPiwwzpjK1euVLt27dSoUSONHTtWkrRnzx5t2rRJffr0SfMzSH7tBx54QFFRUTp9+rTef/99bdq0Sdu3b3cY0ZmYmKjIyEjVrFlT48eP16pVqzRhwgSVKlVKPXr0SPM1kPuQzSYxAGQrM2fONCQZq1atMs6cOWMcPXrUmD9/vhEcHGx4enoaR48eNQzDMBISEowbN244PPfChQtGoUKFjC5dutj3HT
p0yJBk3HPPPcb58+ft+xctWmRIMr7//nv7vqpVqxqFCxc2Ll68aN+3YsUKQ5IRHh5u37dw4UJDkjFmzBiH13/qqac
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1200x1000 with 7 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
|||
|
"\n",
|
|||
|
"_, ax = plt.subplots(ceil(len(models_classification) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
|
|||
|
"\n",
|
|||
|
"for index, key in enumerate(models_classification.keys()):\n",
|
|||
|
"    c_matrix = results[key][\"Confusion_matrix\"]\n",
|
|||
|
"    disp = ConfusionMatrixDisplay(\n",
|
|||
|
"        confusion_matrix=c_matrix, display_labels=[\"Option Exercise\", \"Sale\"]  # sorted label order: 0, 1\n",
|
|||
|
"    ).plot(ax=ax.flat[index])\n",
|
|||
|
"    disp.ax_.set_title(key)\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
|
|||
|
"plt.show()"
|
|||
|
]
|
|||
|
},
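{
"cell_type": "markdown",
"metadata": {},
"source": [
"When `labels` is not given, scikit-learn's `confusion_matrix` orders the labels by sorted value. With the encoding `Option Exercise -> 0`, `Sale -> 1`, the first row and column therefore belong to `Option Exercise`, and `display_labels` must follow that order. The pure-Python sketch below builds the matrix the same way and attaches class names through the inverse of the encoding map. `confusion_matrix_sorted` is a hypothetical helper, not the library function.\n",
"\n",
"```python\n",
"def confusion_matrix_sorted(y_true: list[int], y_pred: list[int]) -> tuple[list[int], list[list[int]]]:\n",
"    # Rows = true labels, columns = predicted labels, both in sorted order\n",
"    labels = sorted(set(y_true) | set(y_pred))\n",
"    index = {label: i for i, label in enumerate(labels)}\n",
"    matrix = [[0] * len(labels) for _ in labels]\n",
"    for t, p in zip(y_true, y_pred):\n",
"        matrix[index[t]][index[p]] += 1\n",
"    return labels, matrix\n",
"\n",
"\n",
"encoding = {'Sale': 1, 'Option Exercise': 0}\n",
"decoding = {code: name for name, code in encoding.items()}\n",
"\n",
"# Toy labels: one Sale predicted as Option Exercise\n",
"labels, matrix = confusion_matrix_sorted([0, 0, 1, 1, 1], [0, 0, 0, 1, 1])\n",
"print([decoding[label] for label in labels])  # class names in the matrix order\n",
"print(matrix)\n",
"```\n"
]
},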
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Подбор гиперпараметров:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Лучшие параметры: {'model__criterion': 'gini', 'model__max_depth': 5, 'model__max_features': 'sqrt', 'model__n_estimators': 10}\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"d:\\ULSTU\\Семестр 5\\AIM-PIbd-31-Masenkin-M-S\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
|||
|
" _data = np.array(data, dtype=dtype, copy=copy,\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Build the pipeline\n",
|
|||
|
"pipeline = Pipeline([\n",
|
|||
|
" (\"processing\", pipeline_end),\n",
|
|||
|
" (\"model\", RandomForestClassifier(random_state=42))\n",
|
|||
|
"])\n",
|
|||
|
"\n",
|
|||
|
"# Parameter grid for the search\n",
|
|||
|
"param_grid: dict[str, Any] = {\n",
|
|||
|
" \"model__n_estimators\": [10, 50, 100],\n",
|
|||
|
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
|
|||
|
" \"model__max_depth\": [5, 7, 10],\n",
|
|||
|
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Hyperparameter tuning with grid search (5-fold cross-validation by default)\n",
|
|||
|
"grid_search = GridSearchCV(estimator=pipeline, \n",
|
|||
|
" param_grid=param_grid,\n",
|
|||
|
" n_jobs=-1)\n",
|
|||
|
"\n",
|
|||
|
"# Fit the search on the training data\n",
|
|||
|
"grid_search.fit(X_df_train, y_df_train)\n",
|
|||
|
"\n",
|
|||
|
"# Best hyperparameters found\n",
|
|||
|
"print(\"Лучшие параметры:\", grid_search.best_params_)"
|
|||
|
]
|
|||
|
},
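{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why were these particular parameters selected? When several candidates tie on the mean cross-validation score, `GridSearchCV` gives them the same rank, and `best_params_` is the first of them in the grid's iteration order. The sketch below mimics that order in pure Python. It assumes, as scikit-learn's `ParameterGrid` does for a single dict, that parameter names are sorted and combined with `itertools.product`. The scores are hypothetical ties, not values from this notebook. Under a full tie, the first combination is exactly the one printed above, which suggests (but does not prove) that all 36 candidates scored the same.\n",
"\n",
"```python\n",
"from itertools import product\n",
"\n",
"param_grid = {\n",
"    'model__n_estimators': [10, 50, 100],\n",
"    'model__max_features': ['sqrt', 'log2'],\n",
"    'model__max_depth': [5, 7, 10],\n",
"    'model__criterion': ['gini', 'entropy'],\n",
"}\n",
"\n",
"# Enumerate candidates: sorted parameter names, the last name varies fastest\n",
"names = sorted(param_grid)\n",
"candidates = [dict(zip(names, values)) for values in product(*(param_grid[n] for n in names))]\n",
"\n",
"# Hypothetical tie: every candidate reaches the same mean CV score\n",
"mean_scores = [1.0] * len(candidates)\n",
"\n",
"# Highest score wins; max() keeps the first candidate among equal scores\n",
"best_index = max(range(len(candidates)), key=lambda i: mean_scores[i])\n",
"print(len(candidates), candidates[best_index])\n",
"```\n"
]
},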
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Comparing hyperparameter sets:\n",
|
|||
|
"\n",
|
|||
|
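"The results show that the default model and the tuned model have identical scores. All key metrics (Precision, Recall, Accuracy, F1-Score and the others) equal 1.0 on both the training and the test set, so tuning brought no measurable gain. Given the target leakage through the `Transaction_Sale` feature noted above, every candidate in the grid most likely reached the same perfect cross-validation score. In that case `GridSearchCV` returns the first combination in the grid, and that combination matches the reported best parameters (`gini`, `max_depth=5`, `sqrt`, `n_estimators=10`)."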
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 401,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Стоковая модель:\n",
|
|||
|
"\tPrecision_train: 1.0\n",
|
|||
|
"\tPrecision_test: 1.0\n",
|
|||
|
"\tRecall_train: 1.0\n",
|
|||
|
"\tRecall_test: 1.0\n",
|
|||
|
"\tAccuracy_train: 1.0\n",
|
|||
|
"\tAccuracy_test: 1.0\n",
|
|||
|
"\tF1_train: 1.0\n",
|
|||
|
"\tF1_test: 1.0\n",
|
|||
|
"\tROC_AUC_test: 1.0\n",
|
|||
|
"\tCohen_kappa_test: 1.0\n",
|
|||
|
"\tMCC_test: 1.0\n",
|
|||
|
"\tConfusion_matrix: [[13 0]\n",
|
|||
|
" [ 0 19]]\n",
|
|||
|
"\n",
|
|||
|
"Оптимизированная модель:\n",
|
|||
|
"\tPrecision_train: 1.0\n",
|
|||
|
"\tPrecision_test: 1.0\n",
|
|||
|
"\tRecall_train: 1.0\n",
|
|||
|
"\tRecall_test: 1.0\n",
|
|||
|
"\tAccuracy_train: 1.0\n",
|
|||
|
"\tAccuracy_test: 1.0\n",
|
|||
|
"\tF1_train: 1.0\n",
|
|||
|
"\tF1_test: 1.0\n",
|
|||
|
"\tROC_AUC_test: 1.0\n",
|
|||
|
"\tCohen_kappa_test: 1.0\n",
|
|||
|
"\tMCC_test: 1.0\n",
|
|||
|
"\tConfusion_matrix: [[13 0]\n",
|
|||
|
" [ 0 19]]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Train the model with the default hyperparameters\n",
|
|||
|
"pipeline.fit(X_df_train, y_df_train)\n",
|
|||
|
"\n",
|
|||
|
"# Predict on the training and test sets\n",
|
|||
|
"y_train_predict = pipeline.predict(X_df_train)\n",
|
|||
|
"y_test_predict = pipeline.predict(X_df_test)\n",
|
|||
|
" \n",
|
|||
|
"# Compute the metrics for the default model\n",
|
|||
|
"base_model_metrics: dict[str, Any] = {\n",
|
|||
|
" \"Precision_train\": metrics.precision_score(y_df_train, y_train_predict),\n",
|
|||
|
" \"Precision_test\": metrics.precision_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"Recall_train\": metrics.recall_score(y_df_train, y_train_predict),\n",
|
|||
|
" \"Recall_test\": metrics.recall_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"Accuracy_train\": metrics.accuracy_score(y_df_train, y_train_predict),\n",
|
|||
|
" \"Accuracy_test\": metrics.accuracy_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"F1_train\": metrics.f1_score(y_df_train, y_train_predict),\n",
|
|||
|
" \"F1_test\": metrics.f1_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"ROC_AUC_test\": metrics.roc_auc_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"MCC_test\": metrics.matthews_corrcoef(y_df_test, y_test_predict),\n",
|
|||
|
" \"Confusion_matrix\": metrics.confusion_matrix(y_df_test, y_test_predict),\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Модель с новыми гипермараметрами\n",
|
|||
|
"optimized_model = RandomForestClassifier(\n",
|
|||
|
" random_state=42,\n",
|
|||
|
" criterion=\"gini\",\n",
|
|||
|
" max_depth=5,\n",
|
|||
|
" max_features=\"sqrt\",\n",
|
|||
|
" n_estimators=10,\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# Создание конвейера для модели с новыми гипермараметрами\n",
|
|||
|
"optimized_model_pipeline = Pipeline(\n",
|
|||
|
" [\n",
|
|||
|
" (\"pipeline\", pipeline_end), \n",
|
|||
|
" (\"model\", optimized_model),\n",
|
|||
|
" ]\n",
|
|||
|
")\n",
|
|||
|
" \n",
|
|||
|
"# Обучение модели с новыми гипермараметрами\n",
|
|||
|
"optimized_model_pipeline.fit(X_df_train, y_df_train)\n",
|
|||
|
" \n",
|
|||
|
"# Предсказание для обучающей и тестовой выборки\n",
|
|||
|
"y_train_predict = optimized_model_pipeline.predict(X_df_train)\n",
|
|||
|
"y_test_predict = optimized_model_pipeline.predict(X_df_test)\n",
|
|||
|
" \n",
|
|||
|
"# Вычисление метрик для модели с новыми гипермараметрами\n",
|
|||
|
"optimized_model_metrics: dict[str, Any] = {\n",
|
|||
|
" \"Precision_train\": metrics.precision_score(y_df_train, y_train_predict),\n",
|
|||
|
" \"Precision_test\": metrics.precision_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"Recall_train\": metrics.recall_score(y_df_train, y_train_predict),\n",
|
|||
|
" \"Recall_test\": metrics.recall_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"Accuracy_train\": metrics.accuracy_score(y_df_train, y_train_predict),\n",
|
|||
|
" \"Accuracy_test\": metrics.accuracy_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"F1_train\": metrics.f1_score(y_df_train, y_train_predict),\n",
|
|||
|
" \"F1_test\": metrics.f1_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"ROC_AUC_test\": metrics.roc_auc_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_df_test, y_test_predict),\n",
|
|||
|
" \"MCC_test\": metrics.matthews_corrcoef(y_df_test, y_test_predict),\n",
|
|||
|
" \"Confusion_matrix\": metrics.confusion_matrix(y_df_test, y_test_predict),\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# Вывод информации\n",
|
|||
|
"print('Стоковая модель:')\n",
|
|||
|
"for metric_name, value in base_model_metrics.items():\n",
|
|||
|
" print(f\"\\t{metric_name}: {value}\")\n",
|
|||
|
"\n",
|
|||
|
"print('\\nОптимизированная модель:')\n",
|
|||
|
"for metric_name, value in optimized_model_metrics.items():\n",
|
|||
|
" print(f\"\\t{metric_name}: {value}\")"
|
|||
|
]
|
|||
|
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}